Unverified Commit 937398cf authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent de27efe6
......@@ -1735,6 +1735,16 @@ dependencies = [
"memchr",
]
[[package]]
name = "ctor"
version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096"
dependencies = [
"quote",
"syn 1.0.109",
]
[[package]]
name = "cudarc"
version = "0.17.8"
......@@ -2295,9 +2305,15 @@ dependencies = [
"anyhow",
"async-trait",
"clap 4.5.53",
"dashmap 6.1.0",
"dynamo-mocker",
"dynamo-runtime",
"dynamo-tokens",
"flume",
"futures",
"indicatif 0.18.3",
"minstant",
"parking_lot",
"prometheus",
"rand 0.9.2",
"rstest 0.18.2",
......@@ -2308,6 +2324,7 @@ dependencies = [
"tokio",
"tokio-util",
"tracing",
"uuid 1.18.1",
"xxhash-rust",
]
......@@ -2899,6 +2916,9 @@ name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
dependencies = [
"getrandom 0.2.16",
]
[[package]]
name = "fax"
......@@ -3023,6 +3043,18 @@ dependencies = [
"num-traits",
]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"fastrand",
"futures-core",
"futures-sink",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
......@@ -5182,6 +5214,16 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "minstant"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fb9b5c752f145ac5046bccc3c4f62892e3c950c1d1eab80c5949cd68a2078db"
dependencies = [
"ctor",
"web-time",
]
[[package]]
name = "mio"
version = "0.6.23"
......@@ -8674,6 +8716,15 @@ dependencies = [
"vob",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
......
......@@ -250,6 +250,12 @@ def parse_args():
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens. By default, output blocks are not tracked.",
)
parser.add_argument(
"--router-event-threads",
type=int,
default=int(os.environ.get("DYN_ROUTER_EVENT_THREADS", "1")),
help="KV Router: Number of event processing threads. When > 1, uses a concurrent radix tree with a thread pool for higher throughput. Can be set via DYN_ROUTER_EVENT_THREADS env var (default: 1).",
)
parser.add_argument(
"--enforce-disagg",
action="store_true",
......@@ -436,6 +442,7 @@ async def async_main():
router_ttl_secs=flags.router_ttl,
router_max_tree_size=flags.router_max_tree_size,
router_prune_target_ratio=flags.router_prune_target_ratio,
router_event_threads=flags.router_event_threads,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
......
......@@ -15,6 +15,7 @@ routing decisions.
import argparse
import asyncio
import logging
import os
from typing import Optional
import uvloop
......@@ -256,6 +257,13 @@ def parse_args():
help="KV Router: Target size ratio after pruning (0.0-1.0). Only used when --no-kv-events is set. Determines how aggressively to prune the tree (default: 0.8)",
)
parser.add_argument(
"--router-event-threads",
type=int,
default=int(os.environ.get("DYN_ROUTER_EVENT_THREADS", "1")),
help="KV Router: Number of event processing threads. When > 1, uses a concurrent radix tree with a thread pool for higher throughput. Can be set via DYN_ROUTER_EVENT_THREADS env var (default: 1).",
)
return parser.parse_args()
......@@ -302,6 +310,7 @@ async def worker(runtime: DistributedRuntime):
router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio,
router_event_threads=args.router_event_threads,
)
# Create service component - use "router" as component name
......
......@@ -174,6 +174,8 @@ The main KV-aware routing arguments:
- `--router-prune-target-ratio`: Target size ratio to prune down to when `--router-max-tree-size` is exceeded. For example, with a value of 0.8 (default) and max tree size of 1048576, the router will prune down to approximately 838860 blocks when the threshold is exceeded. Defaults to 0.8 when `--no-kv-events` is used. This creates headroom before the next pruning cycle.
- `--router-event-threads`: Number of event processing threads for the KV indexer. When set to 1 (default), the router uses a single-threaded radix tree with channel-based event processing, which supports TTL-based expiration and pruning. When set to a value greater than 1, the router uses a concurrent radix tree with a thread pool of the specified size for higher event throughput. Note: the concurrent indexer does not support TTL/pruning (`--router-ttl`, `--router-max-tree-size`, `--router-prune-target-ratio` are ignored when `--router-event-threads > 1`). Can be set via `DYN_ROUTER_EVENT_THREADS` env var.
>[!Note]
> **State persistence** depends on the event transport mode:
> - **NATS Core / Event Plane mode** (default): State persists on workers—router rebuilds state by querying workers on startup. This is the default when workers have `local_indexer` enabled (which is the default). Works with both NATS Core and ZMQ event planes.
......
......@@ -130,6 +130,12 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr
The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks.
The KVIndexer supports two backend implementations, selected via `--router-event-threads`:
- **Single-threaded RadixTree** (default, `--router-event-threads 1`): Events are processed in a dedicated single-threaded tokio runtime via channel-based dispatch. Supports TTL-based expiration and size-based pruning (for `--no-kv-events` approximate mode).
- **ConcurrentRadixTree** (`--router-event-threads N` where N > 1): A thread-safe radix tree with a pool of N worker threads for event processing. Uses sticky worker routing (events for the same worker always go to the same thread) to ensure per-worker event serialization. Read operations (`find_matches`) execute concurrently with writes. Does not support TTL/pruning.
### Inter-Router Communication
In distributed deployments with multiple routers, each router maintains visibility over only a portion of the total requests. To ensure consistent routing decisions, routers synchronize their states through three event types:
......
......@@ -1613,8 +1613,11 @@ version = "0.9.0"
dependencies = [
"anyhow",
"async-trait",
"dashmap 6.1.0",
"dynamo-runtime",
"dynamo-tokens",
"flume",
"parking_lot",
"prometheus",
"rand 0.9.2",
"serde",
......@@ -2139,6 +2142,9 @@ name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
dependencies = [
"getrandom 0.2.16",
]
[[package]]
name = "fax"
......@@ -2252,6 +2258,18 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"fastrand",
"futures-core",
"futures-sink",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
......@@ -6757,6 +6775,15 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
......
......@@ -52,7 +52,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_event_threads=1))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -68,6 +68,7 @@ impl KvRouterConfig {
router_ttl_secs: f64,
router_max_tree_size: usize,
router_prune_target_ratio: f64,
router_event_threads: u32,
) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
......@@ -84,6 +85,7 @@ impl KvRouterConfig {
router_ttl_secs,
router_max_tree_size,
router_prune_target_ratio,
router_event_threads,
},
}
}
......
......@@ -995,6 +995,7 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_event_threads: int = 1,
) -> None:
"""
Create a KV router configuration.
......@@ -1018,6 +1019,8 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_event_threads: Number of event processing threads (default: 1).
When > 1, uses a concurrent radix tree with a thread pool.
"""
...
......
......@@ -13,7 +13,7 @@ repository.workspace = true
[features]
default = []
metrics = ["dep:dynamo-runtime"]
bench = ["dep:clap", "dep:indicatif"]
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dynamo-runtime/integration", "dep:uuid"]
[dependencies]
# repo
......@@ -23,24 +23,35 @@ dynamo-tokens = { workspace = true }
# workspace
anyhow = { workspace = true }
async-trait = { workspace = true }
dashmap = { workspace = true }
prometheus = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true, optional = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
xxhash-rust = { workspace = true }
# dependencies
flume = "0.12.0"
parking_lot = { workspace = true }
# bench (optional)
clap = { version = "4.5", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true }
[dev-dependencies]
rstest = "0.18.2"
rstest_reuse = "0.7.0"
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt", "macros", "time"] }
dynamo-mocker = { workspace = true }
minstant = "0.1.7"
futures = "0.3"
[[bench]]
name = "radix_tree_microbench"
......@@ -51,3 +62,8 @@ required-features = ["bench"]
name = "kv_indexer_bench"
harness = false
required-features = ["bench"]
[[bench]]
name = "mooncake_bench"
harness = false
required-features = ["bench"]
# KV Router Index Data Structures
This document explains the KV cache index implementations: `RadixTree`, `ConcurrentRadixTree`, and `PositionalIndexer` (NestedMap).
## Motivation: The Four Block Identifiers
Every cached KV block in a distributed LLM system needs four pieces of information:
### 1. Local Block Hash (`LocalBlockHash`, u64)
**What**: Hash of the tokens *within* a single block (e.g., 64 tokens).
**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens have the same local hash.
```
Block at position 5: tokens [101, 102, 103, ...]
LocalBlockHash = hash(tokens) = 0xABCD1234
```
### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64)
**What**: Cumulative hash of the entire sequence up to and including this block.
**Why**: Uniquely identifies a block's position in a *specific* sequence history. Two blocks with the same local content but different prefixes have different sequence hashes.
```
Sequence A: [block0, block1, block2]
Sequence B: [block0', block1', block2] // block2 has same content but different prefix
block2 in A: seq_hash = hash(hash(hash(block0) + block1) + block2) = 0x1111
block2 in B: seq_hash = hash(hash(hash(block0') + block1') + block2) = 0x2222
```
**Computation**: `seq_hash[i] = hash(seq_hash[i-1] || local_hash[i])` where `seq_hash[0] = local_hash[0]`
> **Important: Engine-Provided Hashes**
>
> In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events.
>
> **Implications for index implementations:**
>
> - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes.
>
> - **NestedMap**: Requires the ability to compute `ExternalSequenceBlockHash` incrementally for its lazy hash optimization in `find_matches`. To use NestedMap, one of the following is required:
> 1. **Force a known hasher**: Configure the engine to use a specific hashing algorithm that the router can replicate, OR
> 2. **Recompute on the relay**: Have the publisher/relay layer recompute the rolling hash using a known algorithm before forwarding events to the router.
>
> Without this, NestedMap's `find_matches` will fail when encountering `SeqEntry::Multi` cases (multiple seq_hashes at the same position+local_hash) because it cannot disambiguate which entry to use.
### 3. Worker ID (`WorkerWithDpRank`)
**What**: Identifies which worker (inference server) has this block cached.
**Why**: The router needs to know which workers can serve a request based on their cached blocks.
### 4. Position (`u64`)
**What**: The block's index in the sequence (0, 1, 2, ...).
**Why**: Enables efficient prefix matching. Position 0 is the first block, position N-1 is the last.
---
## The Core Operations
Both data structures support three operations:
| Operation | Description | Hot Path? |
|-----------|-------------|-----------|
| `store_blocks` | Add blocks for a worker | No (background) |
| `remove_blocks` | Remove blocks for a worker | No (background) |
| `find_matches` | Find workers with matching prefix | **Yes** (per-request) |
The key insight: **reads (find_matches) are far more frequent than writes (store/remove)**. This motivates different structural tradeoffs.
---
## RadixTree: Tree-Based Index
### Structure
```
RadixTree
├── root: SharedRadixBlock (Rc<RefCell<RadixBlock>>)
└── lookup: HashMap<Worker, HashMap<SeqHash, SharedRadixBlock>>
RadixBlock
├── children: HashMap<LocalBlockHash, SharedRadixBlock>
├── workers: HashSet<Worker>
├── block_hash: Option<SeqHash>
└── recent_uses: VecDeque<Instant>
```
### Visual Representation
```
[root]
/ \
local=0xA local=0xB
↓ ↓
[block0] [block0']
workers: workers:
{W0,W1} {W2}
|
local=0xC
[block1]
workers:
{W0,W1}
|
local=0xD
[block2]
workers:
{W0} ← W1 diverged here
```
### How Operations Work
**store_blocks(worker, parent_hash, blocks)**:
1. Find parent via `lookup[worker][parent_hash]`
2. For each block, traverse/create child nodes using `local_hash`
3. Add worker to each node's `workers` set
4. Update `lookup[worker][seq_hash] = node`
**remove_blocks(worker, block_hashes)**:
1. For each hash, find node via `lookup[worker][hash]`
2. Remove worker from node's `workers` set
3. If `workers` empty, clear children (cascading cleanup)
4. Remove from `lookup[worker]`
**find_matches(local_hashes, early_exit)**:
1. Start at root with all workers as candidates
2. For each position, traverse to child matching `local_hash`
3. Intersect candidates with node's `workers`
4. Track depth where each worker drops out
5. Return `{worker -> depth}` scores
### Complexity
| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) nodes |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D × W) | O(W) |
Where W = number of workers.
---
## PositionalIndexer (NestedMap): Position-First HashMap Index
### Structure
```
PositionalIndexer
├── index: DashMap<(Position, LocalHash), SeqEntry>
├── worker_blocks: DashMap<Worker, RwLock<HashMap<SeqHash, (Position, LocalHash)>>>
└── jump_size: usize
SeqEntry (enum for memory optimization)
├── Single(SeqHash, HashSet<Worker>) // Common case: one seq_hash
└── Multi(HashMap<SeqHash, HashSet<Worker>>) // Rare: multiple prefixes
```
`PositionalIndexer` implements `SyncIndexer` and is thread-safe via `DashMap` (sharded
concurrent map) and `RwLock`. It is designed to be wrapped in a `ThreadPoolIndexer` which
routes write events to dedicated OS threads and executes reads inline.
The `index` uses a flat compound key `(position, local_hash)` in a `DashMap`, which
distributes lock contention across shards while enabling O(1) random-position access for
the jump optimization. The `worker_blocks` reverse lookup uses `DashMap` for the outer
per-worker map and `RwLock<HashMap>` for each worker's block set, since writes to a
given worker are serialized by sticky routing in `ThreadPoolIndexer`.
### Visual Representation
```
index (DashMap with compound keys):
┌──────────────────────┬──────────────────────────────────────┐
│ (pos=0, local=0xA) │ Single(seq=0x1111, {W0,W1}) │
│ (pos=0, local=0xB) │ Single(seq=0x2222, {W2}) │
│ (pos=1, local=0xC) │ Single(seq=0x3333, {W0,W1}) │
│ (pos=2, local=0xD) │ Multi{ │
│ │ seq=0x4444 → {W0}, │
│ │ seq=0x5555 → {W1} ← diverged │
│ │ } │
└──────────────────────┴──────────────────────────────────────┘
worker_blocks (DashMap<Worker, RwLock<HashMap>>):
┌─────────┬─────────────────────────────────────────────────┐
│ W0 │ seq=0x1111 → (pos=0, local=0xA) │
│ │ seq=0x3333 → (pos=1, local=0xC) │
│ │ seq=0x4444 → (pos=2, local=0xD) │
├─────────┼─────────────────────────────────────────────────┤
│ W1 │ seq=0x1111 → (pos=0, local=0xA) │
│ │ seq=0x3333 → (pos=1, local=0xC) │
│ │ seq=0x5555 → (pos=2, local=0xD) │
└─────────┴─────────────────────────────────────────────────┘
```
### How Operations Work
**store_blocks(worker, parent_hash, blocks)**:
1. Find starting position: `pos = worker_blocks[worker][parent_hash].position + 1`
2. For each block at position `i`:
- Insert into `index[(pos+i, local_hash)]` → add worker to SeqEntry
- Insert into `worker_blocks[worker][seq_hash] = (pos+i, local_hash)`
**remove_blocks(worker, block_hashes)**:
1. For each hash, lookup `(pos, local_hash) = worker_blocks[worker][hash]`
2. Remove worker from `index[(pos, local_hash)]`
3. Remove from `worker_blocks[worker]`
4. Cleanup empty SeqEntry entries from the DashMap
**find_matches(local_hashes, early_exit)** with Jump Optimization:
1. Start at position 0, initialize candidates from first block
2. **Jump**: Skip ahead by `jump_size` positions (e.g., 32)
3. At each jump point, check if candidates still match (count-only, no clone)
4. If workers dropped, **scan back** to find exact drain points
5. Continue until sequence exhausted or one worker remains
```
Query: [b0, b1, b2, ..., b63, b64, ..., b127, ...]
↑ ↑ ↑
pos=0 pos=64 pos=128
│ │ │
└── jump ──────────→└── jump ─────────→│
all match? some dropped?
↓ ↓
continue scan [64,128]
```
**Lazy Hash Optimization**:
- Most (position, local_hash) pairs have only ONE seq_hash (SeqEntry::Single)
- Skip seq_hash computation entirely in this case
- Only compute when disambiguation needed (SeqEntry::Multi)
**dump_events()**:
1. Iterate `worker_blocks`, collecting all blocks per worker
2. Sort each worker's blocks by position (parents before children)
3. Emit one single-block `RouterEvent::Stored` per block, synthesizing
`parent_hash` from any seq_hash at the prior position
4. Events can be replayed into a fresh `PositionalIndexer` to reconstruct
the same index state
### Complexity
| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) entries |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D/J + J×W) | O(W) |
Where J = jump_size, W = number of workers. The jump optimization reduces D iterations to D/J jumps plus occasional scans.
---
## Comparison
| Aspect | RadixTree | PositionalIndexer |
|--------|-----------|-------------------|
| **Structure** | Tree with Rc<RefCell<>> nodes | DashMap with compound keys |
| **Concurrent variant** | ConcurrentRadixTree (Arc<RwLock<>> + DashMap) | Thread-safe by default (DashMap + RwLock) |
| **find_matches** | O(D×W) tree traversal | O(D/J) with jump optimization |
| **store_blocks** | O(N) node creation | O(N) DashMap inserts |
| **remove_blocks** | O(N) with cascading cleanup | O(N) with entry cleanup |
| **dump_events** | BFS traversal of tree | Sort by position per worker |
| **Memory** | Higher (Rc/Arc overhead per node) | Lower (flat entries) |
| **Cache locality** | Poor (pointer chasing) | Better (position-first) |
### Benchmark Results (1M blocks, depth 1024, 128 workers)
| Operation | RadixTree | NestedMap | Winner |
|-----------|-----------|-----------|--------|
| STORE_BLOCK | 90µs | 98µs | RadixTree (1.1x) |
| REMOVE_BLOCK | 91µs | 233µs | RadixTree (2.5x) |
| FIND_MATCHES (HIT) | 227µs | **44µs** | **NestedMap (5.2x)** |
| FIND_MATCHES (PARTIAL) | 216µs | **44µs** | **NestedMap (4.9x)** |
**Recommendation**: Use NestedMap for read-heavy workloads (typical router usage).
---
## Why Position Matters for PositionalIndexer
The compound key `(position, local_hash)` in the DashMap enables the jump optimization:
```rust
// Without position-first: must traverse entire tree
for pos in 0..depth {
node = node.children[local_hashes[pos]]; // O(depth) traversals
}
// With position-first: can jump directly to any position
let workers_at_64 = index.get(&(64, local_hashes[64])); // O(1) lookup
let workers_at_128 = index.get(&(128, local_hashes[128])); // O(1) lookup
// Skip positions 1-63, 65-127 entirely!
```
---
## SeqEntry Optimization
The innermost level uses an enum to avoid HashMap allocation in the common case:
```rust
enum SeqEntry {
// Common: one prefix leads to this (position, local_hash)
Single(SeqHash, HashSet<Worker>),
// Rare: different prefixes converge on same (position, local_hash)
Multi(HashMap<SeqHash, HashSet<Worker>>),
}
```
**When does Multi occur?**
Only when two different sequences have:
1. Same local block content at position P
2. Different prefix histories (different seq_hash)
Example:
```
Sequence A: [tok1, tok2, tok3] → positions 0,1,2
Sequence B: [tok4, tok5, tok3] → positions 0,1,2
^^^^
Same local content at pos=2
but different seq_hash!
```
This is rare in practice, so `Single` saves ~48 bytes per entry.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Combined benchmark for KvIndexer and KvIndexerSharded.
//! Combined benchmark for KvIndexer, KvIndexerSharded, and PositionalIndexer (nested).
//!
//! Provides two modes:
//! - `microbench`: Per-operation latency benchmarks comparing single vs sharded indexer
//! - `microbench`: Per-operation latency benchmarks comparing indexer implementations
//! - `stress`: Queue saturation stress test under load
//!
//! Supported indexer types: single, sharded, nested, all
//!
//! Run with:
//! cargo bench --package dynamo-kv-router --bench kv_indexer_bench --features bench -- microbench --help
//! cargo bench --package dynamo-kv-router --bench kv_indexer_bench --features bench -- stress --help
use clap::{Args, Parser, Subcommand, ValueEnum};
use dynamo_kv_router::{
bench_utils::{LatencyStats, SequenceData, generate_sequences, median},
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded},
ConcurrentRadixTree,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, ThreadPoolIndexer,
},
nested_map::PositionalIndexer,
protocols::{LocalBlockHash, RouterEvent},
};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
// ============================================================================
......@@ -29,7 +36,7 @@ use tokio_util::sync::CancellationToken;
#[derive(Parser)]
#[command(name = "kv_indexer_bench")]
#[command(about = "Combined benchmark for KvIndexer and KvIndexerSharded")]
#[command(about = "Combined benchmark for KvIndexer, KvIndexerSharded, and PositionalIndexer")]
struct Cli {
#[command(subcommand)]
command: Command,
......@@ -41,7 +48,7 @@ struct Cli {
#[derive(Subcommand)]
enum Command {
/// Per-operation latency benchmarks comparing single vs sharded indexer
/// Per-operation latency benchmarks comparing indexer implementations
Microbench(MicrobenchArgs),
/// Queue saturation stress test under load
Stress(StressArgs),
......@@ -54,8 +61,12 @@ enum IndexerType {
Single,
/// Sharded KvIndexer (multiple shards with separate trees)
Sharded,
/// Run both and compare
Both,
/// Nested PositionalIndexer (position-based HashMap with jump search)
Nested,
/// Concurrent radix tree (lock-per-node with DashMap lookup)
Concurrent,
/// Run all indexer types and compare
All,
}
/// Common arguments shared between subcommands
......@@ -104,13 +115,17 @@ struct MicrobenchArgs {
num_prefix_prompts: usize,
/// Indexer type to benchmark
#[arg(long, value_enum, default_value = "both")]
#[arg(long, value_enum, default_value = "all")]
indexer_type: IndexerType,
/// Number of shards for sharded indexer
#[arg(long, default_value = "4")]
num_shards: usize,
/// Jump size for nested/positional indexer
#[arg(long, default_value = "32")]
jump_size: usize,
/// Run only specific benchmark (store, find_matches, remove, or all)
#[arg(long, default_value = "all")]
benchmark_type: String,
......@@ -148,6 +163,10 @@ struct StressArgs {
/// Number of shards for sharded indexer
#[arg(long, default_value = "4")]
num_shards: usize,
/// Jump size for nested/positional indexer
#[arg(long, default_value = "32")]
jump_size: usize,
}
// ============================================================================
......@@ -203,6 +222,44 @@ impl BenchableIndexer for KvIndexerSharded {
}
}
#[async_trait::async_trait]
impl BenchableIndexer for ThreadPoolIndexer<PositionalIndexer> {
async fn apply_event(&mut self, event: RouterEvent) {
KvIndexerInterface::apply_event(self, event).await;
}
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError> {
KvIndexerInterface::find_matches(self, sequence).await?;
Ok(())
}
fn name(&self) -> &str {
"PositionalIndexer (nested)"
}
}
#[async_trait::async_trait]
impl BenchableIndexer for ThreadPoolIndexer<ConcurrentRadixTree> {
async fn apply_event(&mut self, event: RouterEvent) {
KvIndexerInterface::apply_event(self, event).await;
}
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError> {
KvIndexerInterface::find_matches(self, sequence).await?;
Ok(())
}
fn name(&self) -> &str {
"ConcurrentRadixTree"
}
}
// ============================================================================
// Microbench Mode
// ============================================================================
......@@ -523,122 +580,111 @@ fn print_microbench_comparison(results: &[MicrobenchResults], _depth: usize) {
println!("COMPARISON SUMMARY");
println!("========================================\n");
let single = &results[0];
let sharded = &results[1];
// Build dynamic column headers
let mut header = format!("{:<30}", "Metric");
for result in results {
header.push_str(&format!(
" {:>15}",
result
.indexer_name
.split_whitespace()
.next()
.unwrap_or(&result.indexer_name)
));
}
println!("{}", header);
println!("{}", "-".repeat(30 + results.len() * 16));
println!(
"{:<30} {:>15} {:>15} {:>10}",
"Metric", "Single", "Sharded", "Ratio"
);
println!("{}", "-".repeat(72));
// Construction
let single_constr = single.construction_time.as_secs_f64() * 1000.0;
let sharded_constr = sharded.construction_time.as_secs_f64() * 1000.0;
println!(
"{:<30} {:>12.2}ms {:>12.2}ms {:>9.2}x",
"Construction time",
single_constr,
sharded_constr,
single_constr / sharded_constr
);
// Construction time
let mut row = format!("{:<30}", "Construction time (ms)");
for result in results {
row.push_str(&format!(
" {:>15.2}",
result.construction_time.as_secs_f64() * 1000.0
));
}
println!("{}", row);
// Store p50
if let (Some(s1), Some(s2)) = (&single.store_stats, &sharded.store_stats) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Store p50",
s1_us,
s2_us,
s1_us / s2_us
);
let mut row = format!("{:<30}", "Store p50 (us)");
for result in results {
if let Some(stats) = &result.store_stats {
row.push_str(&format!(" {:>15.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
}
println!("{}", row);
// Find matches hit p50
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (hit) p50",
s1_us,
s2_us,
s1_us / s2_us
);
let mut row = format!("{:<30}", "Find hit p50 (us)");
for result in results {
if let Some(stats) = &result.find_matches_hit_stats {
row.push_str(&format!(" {:>15.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
}
println!("{}", row);
// Find matches hit p99
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
let s1_us = s1.p99.as_nanos() as f64 / 1000.0;
let s2_us = s2.p99.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (hit) p99",
s1_us,
s2_us,
s1_us / s2_us
);
let mut row = format!("{:<30}", "Find hit p99 (us)");
for result in results {
if let Some(stats) = &result.find_matches_hit_stats {
row.push_str(&format!(" {:>15.2}", stats.p99.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
}
println!("{}", row);
// Find matches miss p50
if let (Some(s1), Some(s2)) = (
&single.find_matches_miss_stats,
&sharded.find_matches_miss_stats,
) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (miss) p50",
s1_us,
s2_us,
s1_us / s2_us
);
let mut row = format!("{:<30}", "Find miss p50 (us)");
for result in results {
if let Some(stats) = &result.find_matches_miss_stats {
row.push_str(&format!(" {:>15.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
}
println!("{}", row);
// Remove p50
if let (Some(s1), Some(s2)) = (&single.remove_stats, &sharded.remove_stats) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Remove p50",
s1_us,
s2_us,
s1_us / s2_us
);
let mut row = format!("{:<30}", "Remove p50 (us)");
for result in results {
if let Some(stats) = &result.remove_stats {
row.push_str(&format!(" {:>15.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
}
println!("{}", row);
// Throughput comparison
// Throughput
println!();
println!(
"{:<30} {:>15} {:>15} {:>10}",
"Throughput (ops/sec)", "Single", "Sharded", "Ratio"
);
println!("{}", "-".repeat(72));
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
println!(
"{:<30} {:>12.0}/s {:>12.0}/s {:>9.2}x",
"Find matches (hit)",
s1.throughput_ops_sec,
s2.throughput_ops_sec,
s2.throughput_ops_sec / s1.throughput_ops_sec
);
let mut header = format!("{:<30}", "Throughput (ops/sec)");
for result in results {
header.push_str(&format!(
" {:>15}",
result
.indexer_name
.split_whitespace()
.next()
.unwrap_or(&result.indexer_name)
));
}
println!("{}", header);
println!("{}", "-".repeat(30 + results.len() * 16));
let mut row = format!("{:<30}", "Find matches (hit)");
for result in results {
if let Some(stats) = &result.find_matches_hit_stats {
row.push_str(&format!(" {:>15.0}", stats.throughput_ops_sec));
} else {
row.push_str(&format!(" {:>15}", "-"));
}
println!("\nNote: Ratio > 1.0 means sharded is faster for that metric.");
}
println!("{}", row);
}
async fn run_microbench_mode(args: MicrobenchArgs) {
......@@ -692,7 +738,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
let mut results = Vec::new();
// Benchmark single indexer
if matches!(args.indexer_type, IndexerType::Single | IndexerType::Both) {
if matches!(args.indexer_type, IndexerType::Single | IndexerType::All) {
let token = CancellationToken::new();
let mut indexer = KvIndexer::new(token.clone(), args.common.block_size, metrics.clone());
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
......@@ -702,7 +748,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
}
// Benchmark sharded indexer
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both) {
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) {
let token = CancellationToken::new();
let mut indexer = KvIndexerSharded::new(
token.clone(),
......@@ -716,6 +762,35 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Benchmark nested indexer
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
let mut indexer = ThreadPoolIndexer::new(
PositionalIndexer::new(args.jump_size),
args.num_shards,
args.common.block_size,
);
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
results.push(result);
indexer.shutdown();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Benchmark concurrent radix tree indexer
if matches!(
args.indexer_type,
IndexerType::Concurrent | IndexerType::All
) {
let mut indexer = ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
args.num_shards,
args.common.block_size,
);
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
results.push(result);
indexer.shutdown();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Print results
if args.format == "csv" {
MicrobenchResults::print_csv_header();
......@@ -727,7 +802,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
result.print(args.common.depth);
}
if results.len() == 2 {
if results.len() >= 2 {
print_microbench_comparison(&results, args.common.depth);
}
}
......@@ -795,16 +870,6 @@ async fn run_stress_test<I: BenchableIndexer + 'static>(
// Phase 3: Pre-generate Lookup Sequences
println!("\nPhase 3: Pre-generating Lookup Sequences");
let expected_requests = (args.arrival_rate * args.duration as f64).ceil() as usize + 100;
let lookup_sequences: Vec<Vec<LocalBlockHash>> = (0..expected_requests)
.map(|i| {
let seq = &sequences[i % sequences.len()];
seq.local_hashes.clone()
})
.collect();
println!(
" Pre-generated {} lookup sequences",
lookup_sequences.len()
);
// Phase 4: Stress Test
println!("\nPhase 4: Stress Test");
......@@ -817,11 +882,12 @@ async fn run_stress_test<I: BenchableIndexer + 'static>(
let start = Instant::now();
let mut request_id = 0u64;
let interval = Duration::from_secs_f64(1.0 / args.arrival_rate);
let mut interval = interval(Duration::from_secs_f64(1.0 / args.arrival_rate));
while start.elapsed() < Duration::from_secs(args.duration) {
let submit_time = Instant::now();
let seq = lookup_sequences[request_id as usize].clone();
let seq = sequences[request_id as usize % sequences.len()]
.local_hashes
.clone();
// Track in-flight
let current = in_flight.fetch_add(1, Ordering::Relaxed) + 1;
......@@ -834,6 +900,7 @@ async fn run_stress_test<I: BenchableIndexer + 'static>(
let verbose = args.common.verbose;
tokio::spawn(async move {
let submit_time = Instant::now();
let result = indexer.find_matches(seq).await;
let complete_time = Instant::now();
in_flight_clone.fetch_sub(1, Ordering::Relaxed);
......@@ -854,7 +921,7 @@ async fn run_stress_test<I: BenchableIndexer + 'static>(
});
request_id += 1;
tokio::time::sleep(interval).await;
interval.tick().await;
}
let submitted = request_id;
......@@ -1039,96 +1106,90 @@ fn print_stress_comparison(results: &[StressResults], args: &StressArgs) {
println!("STRESS TEST COMPARISON SUMMARY");
println!("========================================\n");
let single = &results[0];
let sharded = &results[1];
println!(
"{:<35} {:>18} {:>18} {:>10}",
"Metric", "Single", "Sharded", "Ratio"
);
println!("{}", "-".repeat(85));
// Build dynamic column headers
let mut header = format!("{:<35}", "Metric");
for result in results {
let short_name = result
.indexer_name
.split_whitespace()
.next()
.unwrap_or(&result.indexer_name);
header.push_str(&format!(" {:>18}", short_name));
}
println!("{}", header);
println!("{}", "-".repeat(35 + results.len() * 19));
// Construction time
let single_constr = single.construction_time.as_secs_f64() * 1000.0;
let sharded_constr = sharded.construction_time.as_secs_f64() * 1000.0;
println!(
"{:<35} {:>15.2}ms {:>15.2}ms {:>9.2}x",
"Construction time",
single_constr,
sharded_constr,
single_constr / sharded_constr
);
let mut row = format!("{:<35}", "Construction time (ms)");
for result in results {
row.push_str(&format!(
" {:>18.2}",
result.construction_time.as_secs_f64() * 1000.0
));
}
println!("{}", row);
// Baseline service time
let single_baseline = single.baseline_service_time.as_nanos() as f64 / 1000.0;
let sharded_baseline = sharded.baseline_service_time.as_nanos() as f64 / 1000.0;
println!(
"{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
"Baseline service time",
single_baseline,
sharded_baseline,
single_baseline / sharded_baseline
);
let mut row = format!("{:<35}", "Baseline service time (us)");
for result in results {
row.push_str(&format!(
" {:>18.2}",
result.baseline_service_time.as_nanos() as f64 / 1000.0
));
}
println!("{}", row);
// Completed requests
println!(
"{:<35} {:>18} {:>18} {:>9.2}x",
"Completed requests",
single.completed,
sharded.completed,
sharded.completed as f64 / single.completed as f64
);
let mut row = format!("{:<35}", "Completed requests");
for result in results {
row.push_str(&format!(" {:>18}", result.completed));
}
println!("{}", row);
// Max in-flight
println!(
"{:<35} {:>18} {:>18}",
"Max in-flight", single.max_in_flight, sharded.max_in_flight
);
let mut row = format!("{:<35}", "Max in-flight");
for result in results {
row.push_str(&format!(" {:>18}", result.max_in_flight));
}
println!("{}", row);
// Timed out
println!(
"{:<35} {:>18} {:>18}",
"Timed out", single.timed_out, sharded.timed_out
);
let mut row = format!("{:<35}", "Timed out");
for result in results {
row.push_str(&format!(" {:>18}", result.timed_out));
}
println!("{}", row);
// Latency comparison
if let (Some(s1), Some(s2)) = (
LatencyStats::from_durations(single.latencies.clone()),
LatencyStats::from_durations(sharded.latencies.clone()),
) {
let s1_p50 = s1.p50.as_nanos() as f64 / 1000.0;
let s2_p50 = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
"Latency p50",
s1_p50,
s2_p50,
s1_p50 / s2_p50
);
// Latency p50
let mut row = format!("{:<35}", "Latency p50 (us)");
for result in results {
if let Some(stats) = LatencyStats::from_durations(result.latencies.clone()) {
row.push_str(&format!(" {:>18.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>18}", "-"));
}
}
println!("{}", row);
let s1_p99 = s1.p99.as_nanos() as f64 / 1000.0;
let s2_p99 = s2.p99.as_nanos() as f64 / 1000.0;
println!(
"{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
"Latency p99",
s1_p99,
s2_p99,
s1_p99 / s2_p99
);
// Latency p99
let mut row = format!("{:<35}", "Latency p99 (us)");
for result in results {
if let Some(stats) = LatencyStats::from_durations(result.latencies.clone()) {
row.push_str(&format!(" {:>18.2}", stats.p99.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>18}", "-"));
}
}
println!("{}", row);
// Achieved throughput
let test_duration = args.duration as f64 + args.in_flight_timeout as f64;
let s1_throughput = single.completed as f64 / test_duration;
let s2_throughput = sharded.completed as f64 / test_duration;
println!(
"{:<35} {:>14.1}/s {:>14.1}/s {:>9.2}x",
"Achieved throughput",
s1_throughput,
s2_throughput,
s2_throughput / s1_throughput
);
let mut row = format!("{:<35}", "Achieved throughput (req/s)");
for result in results {
let throughput = result.completed as f64 / test_duration;
row.push_str(&format!(" {:>18.1}", throughput));
}
println!("\nNote: Ratio > 1.0 means sharded is better for that metric.");
println!("{}", row);
}
async fn run_stress_mode(args: StressArgs) {
......@@ -1159,9 +1220,12 @@ async fn run_stress_mode(args: StressArgs) {
eprintln!("Error: arrival_rate must be > 0.0");
std::process::exit(1);
}
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both) && args.num_shards == 0
if matches!(
args.indexer_type,
IndexerType::Sharded | IndexerType::Nested | IndexerType::All
) && args.num_shards == 0
{
eprintln!("Error: num_shards must be > 0 when using Sharded or Both indexer type");
eprintln!("Error: num_shards must be > 0 when using Sharded, Nested, or All indexer type");
std::process::exit(1);
}
......@@ -1186,8 +1250,12 @@ async fn run_stress_mode(args: StressArgs) {
println!(" Duration: {}s", args.duration);
println!(" In-flight timeout: {}s", args.in_flight_timeout);
println!(" Indexer type: {:?}", args.indexer_type);
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both) {
println!(" Num shards: {}", args.num_shards);
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) {
println!(" Num shards (sharded): {}", args.num_shards);
}
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
println!(" Num workers (nested): {}", args.num_shards);
println!(" Jump size (nested): {}", args.jump_size);
}
// Generate sequences
......@@ -1209,7 +1277,7 @@ async fn run_stress_mode(args: StressArgs) {
let mut all_results = Vec::new();
// Test single indexer
if matches!(args.indexer_type, IndexerType::Single | IndexerType::Both) {
if matches!(args.indexer_type, IndexerType::Single | IndexerType::All) {
let token = CancellationToken::new();
let mut indexer = KvIndexer::new(token.clone(), args.common.block_size, metrics.clone());
......@@ -1251,7 +1319,7 @@ async fn run_stress_mode(args: StressArgs) {
}
// Test sharded indexer
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both) {
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) {
let token = CancellationToken::new();
let mut indexer = KvIndexerSharded::new(
token.clone(),
......@@ -1297,8 +1365,109 @@ async fn run_stress_mode(args: StressArgs) {
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Print comparison if both were run
if all_results.len() == 2 {
// Test nested indexer
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
let indexer = ThreadPoolIndexer::new(
PositionalIndexer::new(args.jump_size),
args.num_shards,
args.common.block_size,
);
println!(
"\n Applying {} store events to PositionalIndexer...",
sequences.len()
);
let construction_start = Instant::now();
for (event_id, seq) in sequences.iter().enumerate() {
let event = seq.to_store_event(event_id as u64);
indexer.apply_event(event).await;
if args.common.verbose && (event_id + 1) % 100 == 0 {
println!(" Applied {}/{} events...", event_id + 1, sequences.len());
}
}
indexer.flush().await;
let construction_time = construction_start.elapsed();
let construction_events = sequences.len() as u64;
println!(" Tree construction completed in {:?}", construction_time);
println!(
" Throughput: {:.0} events/sec",
construction_events as f64 / construction_time.as_secs_f64()
);
tokio::time::sleep(Duration::from_millis(100)).await;
let indexer = Arc::new(indexer);
let mut results = run_stress_test(indexer.clone(), &sequences, &args).await;
results.construction_time = construction_time;
results.construction_events = construction_events;
print_stress_results(&args, &results);
all_results.push(results);
indexer.shutdown();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Test concurrent radix tree indexer
if matches!(
args.indexer_type,
IndexerType::Concurrent | IndexerType::All
) {
let indexer = ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
args.num_shards,
args.common.block_size,
);
println!(
"\n Applying {} store events to ConcurrentRadixTree...",
sequences.len()
);
let construction_start = Instant::now();
for (event_id, seq) in sequences.iter().enumerate() {
let event = seq.to_store_event(event_id as u64);
indexer.apply_event(event).await;
if args.common.verbose && (event_id + 1) % 100 == 0 {
println!(" Applied {}/{} events...", event_id + 1, sequences.len());
}
}
indexer.flush().await;
let construction_time = construction_start.elapsed();
let construction_events = sequences.len() as u64;
println!(" Tree construction completed in {:?}", construction_time);
println!(
" Throughput: {:.0} events/sec",
construction_events as f64 / construction_time.as_secs_f64()
);
tokio::time::sleep(Duration::from_millis(100)).await;
let indexer = Arc::new(indexer);
let mut results = run_stress_test(indexer.clone(), &sequences, &args).await;
results.construction_time = construction_time;
results.construction_events = construction_events;
print_stress_results(&args, &results);
all_results.push(results);
indexer.shutdown();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Print comparison if multiple were run
if all_results.len() >= 2 {
print_stress_comparison(&all_results, &args);
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use clap::{Parser, Subcommand};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
};
use dynamo_kv_router::protocols::RouterEvent;
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer};
use rand::prelude::*;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::sync::Arc;
use uuid::Uuid;
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use dynamo_mocker::Scheduler;
use dynamo_mocker::protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs};
use indicatif::{ProgressBar, ProgressStyle};
use std::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use serde::{Deserialize, Serialize};
/// Indexer backend selection and its backend-specific parameters.
#[derive(Subcommand, Debug, Clone)]
enum IndexerArgs {
/// Single-threaded radix tree indexer.
RadixTree {},
/// Sharded radix tree indexer that partitions workers across independent shards.
RadixTreeSharded {
/// Number of independent shards to split workers across.
#[clap(long, default_value = "4")]
num_shards: usize,
},
/// Position-based nested map indexer with jump search.
NestedMap {
/// Number of positions to skip during jump search before scanning back.
#[clap(long, default_value = "8")]
jump_size: usize,
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
/// Lock-based concurrent radix tree indexer.
ConcurrentRadixTree {
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
}
impl IndexerArgs {
/// Construct the concrete indexer from the parsed CLI args.
fn build(self, args: &Args) -> Arc<dyn KvIndexerInterface + Send + Sync> {
let cancel_token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
match self {
IndexerArgs::RadixTree {} => {
Arc::new(KvIndexer::new(cancel_token, args.block_size, metrics))
}
IndexerArgs::RadixTreeSharded { num_shards } => Arc::new(KvIndexerSharded::new(
cancel_token,
num_shards,
args.block_size,
metrics,
)),
IndexerArgs::NestedMap {
jump_size,
num_event_workers,
} => Arc::new(ThreadPoolIndexer::new(
PositionalIndexer::new(jump_size),
num_event_workers,
args.block_size,
)),
IndexerArgs::ConcurrentRadixTree { num_event_workers } => {
Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
num_event_workers,
args.block_size,
))
}
}
}
}
#[derive(Parser, Debug)]
#[clap(version, about, long_about = None)]
struct Args {
/// Path to a JSONL mooncake trace file. Each line is a JSON object with
/// fields: uuid, timestamp, hash_ids, output_length.
mooncake_trace_path: String,
/// Number of GPU blocks available in the mock engine's KV cache.
/// Smaller values force more evictions and produce more remove events.
#[clap(long, default_value = "2048")]
num_gpu_blocks: usize,
/// Number of tokens per KV cache block.
#[clap(long, default_value = "512")]
block_size: u32,
/// Wall-clock duration (ms) over which the trace is replayed during event
/// generation. Longer values produce more accurate inter-request timing but
/// increase setup time.
#[clap(long, default_value = "30000")]
trace_simulation_duration_ms: u64,
/// Wall-clock duration (ms) over which the benchmark replays requests and
/// events against the indexer under test.
#[clap(long, default_value = "60000")]
benchmark_duration_ms: u64,
/// Number of unique simulated inference workers. Each gets a random
/// partition of the trace and its own mock engine for event generation.
#[clap(short, long, default_value = "64")]
num_unique_inference_workers: usize,
/// How many times to duplicate the set of unique workers during the
/// benchmark phase. Total workers = num_unique_inference_workers * factor.
/// Duplicated workers replay identical traces with distinct worker IDs.
#[clap(short = 'd', long, default_value = "1")]
inference_worker_duplication_factor: usize,
/// RNG seed for reproducible worker-to-trace assignment.
#[clap(long, default_value = "42")]
seed: u64,
/// Indexer backend to benchmark (defaults to radix-tree if not specified).
#[clap(subcommand)]
indexer: Option<IndexerArgs>,
/// Ignored - passed by cargo bench harness.
#[arg(long, hide = true, global = true)]
bench: bool,
}
impl Args {
/// Return the indexer config, falling back to RadixTree if none was specified.
fn get_indexer(&self) -> IndexerArgs {
self.indexer.clone().unwrap_or(IndexerArgs::RadixTree {})
}
}
/// A single request deserialized from the mooncake trace JSONL.
#[derive(Serialize, Deserialize, Clone)]
struct MooncakeRequest {
#[serde(default = "Uuid::new_v4")]
uuid: uuid::Uuid,
timestamp: u64,
hash_ids: Vec<u64>,
output_length: u64,
}
/// Collects KV cache events emitted by the mock engine during event generation,
/// tagging each with the wall-clock instant it was produced.
struct EventCollector {
events: Mutex<Option<Vec<(KvCacheEvent, Instant)>>>,
}
impl EventCollector {
fn new() -> Arc<Self> {
Arc::new(Self {
events: Mutex::new(Some(Vec::new())),
})
}
/// Take ownership of the collected events. Can only be called once.
fn get_events(self: Arc<Self>) -> Vec<(KvCacheEvent, Instant)> {
self.events.lock().unwrap().take().unwrap()
}
}
impl KvCacheEventSink for EventCollector {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
let timestamp = Instant::now();
if let Some(events) = self.events.lock().unwrap().as_mut() {
events.push((event, timestamp));
}
Ok(())
}
}
/// A single entry in a worker's merged benchmark timeline.
#[derive(Clone)]
enum WorkerTraceEntry {
/// A find_matches request with pre-computed block hashes.
Request(Vec<LocalBlockHash>),
/// A KV cache event (store/remove/clear) to apply to the indexer.
Event(KvCacheEvent),
}
/// A timestamped entry in a worker's benchmark trace, used to replay requests
/// and events at the correct relative timing.
#[derive(Clone)]
struct WorkerTrace {
entry: WorkerTraceEntry,
timestamp_us: u64,
}
/// Load the mooncake trace from disk and randomly partition requests across
/// `num_unique_inference_workers` worker buckets using the configured seed.
fn process_mooncake_trace(args: &Args) -> anyhow::Result<Vec<Vec<MooncakeRequest>>> {
let mut traces: Vec<Vec<MooncakeRequest>> = Vec::new();
for _ in 0..args.num_unique_inference_workers {
traces.push(Vec::new());
}
let mut rng = StdRng::seed_from_u64(args.seed);
let file = File::open(&args.mooncake_trace_path)?;
let reader = BufReader::new(file);
println!("Loading trace...");
let progress = make_progress_bar(None);
for line in reader.lines() {
let request = serde_json::from_str::<MooncakeRequest>(&line?)?;
traces[rng.random_range(0..args.num_unique_inference_workers)].push(request);
progress.inc(1);
}
Ok(traces)
}
/// Linearly rescale all timestamps in a worker's trace so the total span equals
/// `duration` milliseconds.
fn scale_mooncake_trace(trace: &Vec<MooncakeRequest>, duration: u64) -> Vec<MooncakeRequest> {
let total_duration = trace.last().unwrap().timestamp - trace.first().unwrap().timestamp;
trace
.iter()
.map(|request| MooncakeRequest {
timestamp: request.timestamp * duration / total_duration,
..request.clone()
})
.collect::<Vec<MooncakeRequest>>()
}
/// Expand a request's block-level hash_ids into per-token IDs by repeating each
/// hash_id `block_size` times.
fn tokens_from_request(request: &MooncakeRequest, block_size: u32) -> Vec<u32> {
request
.hash_ids
.iter()
.flat_map(|id| (0..block_size).map(|_| *id as u32))
.collect()
}
/// Create a styled progress bar, optionally with a known total length.
fn make_progress_bar(total: Option<u64>) -> ProgressBar {
let progress = match total {
Some(total) => ProgressBar::new(total),
None => ProgressBar::no_length(),
};
progress.set_style(
ProgressStyle::with_template(
"[{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}",
)
.unwrap()
.progress_chars("#>-"),
);
progress
}
/// Replay each worker's request trace through a mock engine in real-time to
/// produce the KV cache events (store/remove/clear) that the engine would emit.
///
/// Returns one event list per worker, each entry paired with the wall-clock
/// instant it was produced. Event ordering within a worker is guaranteed
/// monotonically non-decreasing by timestamp.
async fn generate_events(
traces: &Vec<Vec<MooncakeRequest>>,
args: &Args,
) -> anyhow::Result<Vec<Vec<(KvCacheEvent, Instant)>>> {
println!("Generating events...");
let sched_args = MockEngineArgs::builder()
.num_gpu_blocks(args.num_gpu_blocks)
.block_size(args.block_size as usize)
.speedup_ratio(0.0)
.enable_prefix_caching(true)
.max_num_batched_tokens(None)
.max_num_seqs(None)
.build()?;
let scaled_traces = traces
.iter()
.map(|worker_trace| scale_mooncake_trace(worker_trace, args.trace_simulation_duration_ms));
let progress = make_progress_bar(Some(
traces.iter().map(|worker| worker.len() as u64).sum::<u64>(),
));
let mut tasks: Vec<JoinHandle<Vec<(KvCacheEvent, Instant)>>> = Vec::new();
for worker_trace in scaled_traces {
let sched_args = sched_args.clone();
let progress = progress.clone();
let block_size = args.block_size;
tasks.push(tokio::spawn(async move {
let collector = EventCollector::new();
let scheduler = Scheduler::new(sched_args, 0, None, Some(collector.clone()), None);
let mut i = 0;
let mut target = Instant::now();
while i < worker_trace.len() {
let prev_i = i;
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
i += 1;
while i < worker_trace.len()
&& worker_trace[i].timestamp == worker_trace[i - 1].timestamp
{
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
i += 1;
}
if i < worker_trace.len() {
target += Duration::from_millis(
worker_trace[i].timestamp - worker_trace[i - 1].timestamp,
);
}
tokio::time::sleep_until(tokio::time::Instant::from(target)).await;
progress.inc((i - prev_i) as u64);
}
collector.get_events()
}));
}
let mut events = Vec::new();
for task in tasks {
events.push(task.await?);
}
for worker_events in &events {
for i in 1..worker_events.len() {
assert!(worker_events[i].1 >= worker_events[i - 1].1);
}
}
println!(
"Generated {} events. Processing...",
events.iter().map(|e| e.len()).sum::<usize>()
);
if progress.elapsed() > Duration::from_millis(args.trace_simulation_duration_ms * 11 / 10) {
eprintln!(
"Warning: Generated events took significantly longer than the trace simulation duration. Inaccurate timing information has been produced. Rerun with a larger --trace-simulation-duration-ms."
);
}
let mut num_stored_events = 0;
let mut num_removed_events = 0;
for event in events.iter().flatten() {
match event.0.data {
KvCacheEventData::Stored(_) => num_stored_events += 1,
KvCacheEventData::Removed(_) => num_removed_events += 1,
_ => (),
}
}
println!("Store events: {}", num_stored_events);
println!("Remove events: {}", num_removed_events);
Ok(events)
}
/// Merge each worker's request trace and event trace into a single
/// time-ordered sequence of `WorkerTrace` entries suitable for benchmark
/// replay.
///
/// Timestamps are rescaled from the original trace / simulation durations
/// into the benchmark duration (microseconds).
fn prepare_worker_traces(
traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args,
) -> Vec<Vec<WorkerTrace>> {
assert!(traces.len() == events.len());
let scaled_request_traces: Vec<_> = traces
.into_iter()
.map(|trace| {
let trace_duration_ms =
trace.last().unwrap().timestamp - trace.first().unwrap().timestamp;
trace
.into_iter()
.map(|request| WorkerTrace {
timestamp_us: request.timestamp * 1000 * args.benchmark_duration_ms
/ trace_duration_ms,
entry: WorkerTraceEntry::Request(
request
.hash_ids
.iter()
.map(|id| LocalBlockHash(*id))
.collect(),
),
})
.collect::<Vec<_>>()
})
.collect();
let scaled_event_traces: Vec<_> = events
.into_iter()
.map(|worker_events| {
let start_instant = worker_events.first().unwrap().1;
worker_events
.into_iter()
.map(|(event, timestamp)| WorkerTrace {
timestamp_us: (timestamp - start_instant).as_micros() as u64
* args.benchmark_duration_ms
/ args.trace_simulation_duration_ms,
entry: WorkerTraceEntry::Event(event),
})
.collect::<Vec<_>>()
})
.collect();
scaled_request_traces
.into_iter()
.zip(scaled_event_traces.into_iter())
.map(|(request_trace, event_trace)| {
let mut merged: Vec<WorkerTrace> = request_trace
.into_iter()
.chain(event_trace.into_iter())
.collect();
merged.sort_by_key(|entry| entry.timestamp_us);
merged
})
.collect()
}
/// Run the benchmark: replay each worker's merged trace against the indexer,
/// measuring find_matches latency and event processing throughput.
///
/// Workers are spawned as tokio tasks, each replaying its trace at the
/// original inter-entry timing. After all workers finish, the event queue is
/// flushed and latency percentiles / throughput stats are printed.
async fn run_benchmark(
indexer: Arc<dyn KvIndexerInterface + Send + Sync>,
traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args,
) -> anyhow::Result<()> {
let worker_traces = prepare_worker_traces(traces, events, args);
let worker_traces = worker_traces
.into_iter()
.map(|trace| Arc::new(trace))
.collect::<Vec<_>>();
let progress = make_progress_bar(Some(
worker_traces
.iter()
.map(|trace| trace.len() as u64)
.sum::<u64>()
* args.inference_worker_duplication_factor as u64,
));
let mut tasks = Vec::new();
for replica in 0..args.inference_worker_duplication_factor {
for (worker_id, worker_trace) in worker_traces.iter().enumerate() {
let indexer = indexer.clone();
let trace = worker_trace.clone();
let progress = progress.clone();
let worker_id = worker_id + replica * worker_traces.len();
tasks.push(tokio::spawn(async move {
let mut request_latencies = Vec::with_capacity(trace.len());
let submit = |entry: WorkerTrace| async {
match entry.entry {
WorkerTraceEntry::Request(request) => {
let start = minstant::Instant::now();
indexer.find_matches(request).await?;
Ok::<Option<u64>, anyhow::Error>(
Some(start.elapsed().as_nanos() as u64),
)
}
WorkerTraceEntry::Event(event) => {
indexer
.apply_event(RouterEvent {
worker_id: worker_id as u64,
event,
})
.await;
Ok(None)
}
}
};
let mut target = Instant::now();
let mut trace = trace.iter().peekable();
let mut local_count = 0;
while let Some(entry) = trace.next() {
let mut processed = 1;
let entry_timestamp_us = entry.timestamp_us;
if let Some(latency) = submit(entry.clone()).await? {
request_latencies.push(latency);
}
while let Some(next) = trace.peek() {
if next.timestamp_us == entry_timestamp_us {
if let Some(latency) = submit(trace.next().unwrap().clone()).await? {
request_latencies.push(latency);
}
processed += 1;
} else {
break;
}
}
if let Some(next) = trace.peek() {
target += Duration::from_micros(next.timestamp_us - entry_timestamp_us);
}
if target > Instant::now() {
tokio::time::sleep_until(target).await;
}
local_count += processed;
if local_count > 100 {
progress.inc(local_count);
local_count = 0;
}
}
progress.inc(local_count);
Ok::<_, anyhow::Error>(request_latencies)
}));
}
}
let mut latencies = Vec::new();
for task in tasks {
latencies.extend(task.await??);
}
if progress.elapsed() > Duration::from_millis(args.benchmark_duration_ms * 11 / 10) {
eprintln!(
"WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms."
)
}
println!("Flushing event queue...");
let request_duration = progress.elapsed();
let flush_start = Instant::now();
let flush_size = indexer.flush().await;
let flush_duration = flush_start.elapsed();
let event_duration = progress.elapsed();
let total_events = worker_traces
.iter()
.map(|trace| {
trace
.iter()
.filter(|trace| matches!(trace.entry, WorkerTraceEntry::Event(_)))
.count()
})
.sum::<usize>()
* args.inference_worker_duplication_factor;
let total_requests = worker_traces.iter().map(|trace| trace.len()).sum::<usize>()
* args.inference_worker_duplication_factor
- total_events;
let event_queue_flush_percentage = flush_size as f32 / total_events as f32 * 100.0;
println!("Event queue flush duration: {:?}", flush_duration);
println!(
"Event queue flush size: {} ({}% of total events)",
flush_size, event_queue_flush_percentage
);
if event_queue_flush_percentage > 5.0 {
eprintln!(
"ERROR: Over 5% of events were unable to be completed within the benchmark duration.
Results are invalid. Rerun with a smaller trace or less worker duplication."
);
}
println!(
"Request Throughput: {} req/s",
total_requests as f32 / request_duration.as_millis() as f32 * 1000.0
);
println!(
"Event Throughput: {} events/s",
total_events as f32 / event_duration.as_millis() as f32 * 1000.0
);
latencies.sort_unstable();
println!(
"Latency p50: {}us",
latencies[latencies.len() / 2] as f32 / 1000.0
);
println!(
"Latency p95: {}us",
latencies[latencies.len() * 95 / 100] as f32 / 1000.0
);
println!(
"Latency p99: {}us",
latencies[latencies.len() * 99 / 100] as f32 / 1000.0
);
println!(
"Latency max: {}us",
*latencies.last().unwrap() as f32 / 1000.0
);
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
let traces = process_mooncake_trace(&args)?;
let events = generate_events(&traces, &args).await?;
let indexer = args.get_indexer().build(&args);
run_benchmark(indexer, traces, events, &args).await?;
Ok(())
}
......@@ -15,30 +15,31 @@
use clap::{Parser, ValueEnum};
use dynamo_kv_router::{
RadixTree, RouterEvent,
ConcurrentRadixTree, OverlapScores, PositionalIndexer, RadixTree, RouterEvent, SyncIndexer,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
compute_block_hash_for_seq,
flat_hashmap::FlatHashMap,
protocols::LocalBlockHash,
};
use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
/// Unified interface for RadixTree, ConcurrentRadixTree, and PositionalIndexer benchmarking.
///
/// Both structures have feature parity for store, remove, find_matches, and current_size.
/// All structures have feature parity for store, remove, find_matches, and current_size.
/// The key difference is find_matches input:
/// - RadixTree: uses LocalBlockHash (tokens_hash)
/// - FlatHashMap: uses ExternalSequenceBlockHash (cumulative sequence hash)
/// - RadixTree/ConcurrentRadixTree: uses LocalBlockHash (tokens_hash)
/// - PositionalIndexer: uses LocalBlockHash (same as tree; internal mapping uses sequence hashes)
enum KvIndex {
Tree(RadixTree),
Flat(FlatHashMap),
Concurrent(ConcurrentRadixTree),
Nested(PositionalIndexer),
}
impl KvIndex {
fn name(&self) -> &'static str {
match self {
KvIndex::Tree(_) => "RadixTree",
KvIndex::Flat(_) => "FlatHashMap",
KvIndex::Concurrent(_) => "ConcurrentRadixTree",
KvIndex::Nested(_) => "PositionalIndexer",
}
}
......@@ -47,8 +48,11 @@ impl KvIndex {
KvIndex::Tree(tree) => {
let _ = tree.apply_event(event);
}
KvIndex::Flat(map) => {
map.apply_event(event);
KvIndex::Concurrent(tree) => {
let _ = tree.apply_event(event);
}
KvIndex::Nested(map) => {
let _ = map.apply_event(event).ok();
}
}
}
......@@ -58,7 +62,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(local_hashes, early_exit),
KvIndex::Flat(map) => map.find_matches(local_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&local_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&local_hashes, early_exit),
};
start.elapsed()
}
......@@ -70,7 +75,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(miss_hashes, early_exit),
KvIndex::Flat(map) => map.find_matches(miss_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&miss_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&miss_hashes, early_exit),
};
start.elapsed()
}
......@@ -89,7 +95,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(partial, early_exit),
KvIndex::Flat(map) => map.find_matches(partial, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&partial, early_exit),
KvIndex::Nested(map) => map.find_matches(&partial, early_exit),
};
start.elapsed()
}
......@@ -97,7 +104,27 @@ impl KvIndex {
fn current_size(&self) -> usize {
match self {
KvIndex::Tree(tree) => tree.current_size(),
KvIndex::Flat(map) => map.current_size(),
KvIndex::Concurrent(tree) => tree.current_size(),
KvIndex::Nested(map) => map.current_size(),
}
}
fn find_matches(&self, local_hashes: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
match self {
KvIndex::Tree(tree) => tree.find_matches(local_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&local_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&local_hashes, early_exit),
}
}
fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
match self {
KvIndex::Tree(tree) => tree.dump_tree_as_events(),
KvIndex::Concurrent(tree) => tree.dump_tree_as_events(),
KvIndex::Nested(_) => {
// NestedMap does not support dump_tree_as_events
vec![]
}
}
}
}
......@@ -197,44 +224,22 @@ struct Args {
#[arg(long, default_value = "42")]
seed: u64,
/// Use flat HashMap baseline instead of radix tree (for comparison)
/// Use nested map instead of radix tree (for comparison)
#[arg(long)]
flat_hashmap: bool,
}
/// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
print!(
" Building tree with {} sequences ({} blocks)... ",
sequences.len(),
num_blocks
);
std::io::Write::flush(&mut std::io::stdout()).unwrap();
nested_map: bool,
let start = Instant::now();
let mut tree = RadixTree::new();
for (event_id, seq) in sequences.iter().enumerate() {
let event = seq.to_store_event(event_id as u64);
let _ = tree.apply_event(event);
}
let elapsed = start.elapsed();
println!(
"done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)",
elapsed,
sequences.len() as f64 / elapsed.as_secs_f64(),
num_blocks as f64 / elapsed.as_secs_f64()
);
tree
/// Use concurrent radix tree instead of single-threaded radix tree
#[arg(long)]
concurrent: bool,
}
/// Build a pre-populated KvIndex (prints timing info)
fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
fn build_index(sequences: &[SequenceData], use_nested_map: bool, use_concurrent: bool) -> KvIndex {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
let name = if use_flat_hashmap {
"FlatHashMap"
let name = if use_nested_map {
"NestedMap"
} else if use_concurrent {
"ConcurrentRadixTree"
} else {
"RadixTree"
};
......@@ -247,8 +252,10 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
std::io::Write::flush(&mut std::io::stdout()).unwrap();
let start = Instant::now();
let mut index = if use_flat_hashmap {
KvIndex::Flat(FlatHashMap::new())
let mut index = if use_nested_map {
KvIndex::Nested(PositionalIndexer::new(32))
} else if use_concurrent {
KvIndex::Concurrent(ConcurrentRadixTree::new())
} else {
KvIndex::Tree(RadixTree::new())
};
......@@ -329,10 +336,10 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let mut index = build_index(&sequences, args.flat_hashmap);
let mut index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking {} ({}) ===", op_name, index.name());
println!(" Size: {} blocks", index.current_size());
......@@ -391,10 +398,10 @@ fn bench_find_matches(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let index = build_index(&sequences, args.flat_hashmap);
let index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking FIND_MATCHES ({}) ===", index.name());
println!(
" Built with {} sequences, {} total blocks",
......@@ -696,12 +703,12 @@ fn bench_sweep(args: &Args) {
args.prefix_prompt_ratio,
num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..];
let mut tree = build_tree(tree_sequences);
let mut index = build_index(tree_sequences, args.nested_map, args.concurrent);
// --- STORE benchmark ---
let mut store_durations = Vec::with_capacity(args.sweep_iterations);
......@@ -718,12 +725,12 @@ fn bench_sweep(args: &Args) {
let store_event = truncated.to_store_event(i as u64);
let start = Instant::now();
let _ = tree.apply_event(store_event);
index.apply_event(store_event);
store_durations.push(start.elapsed());
// Remove to restore tree state (untimed)
// Remove to restore index state (untimed)
let remove_event = truncated.to_remove_event(i as u64);
let _ = tree.apply_event(remove_event);
index.apply_event(remove_event);
}
// --- REMOVE benchmark ---
......@@ -738,12 +745,12 @@ fn bench_sweep(args: &Args) {
let remove_event = truncated.to_remove_event(i as u64);
let start = Instant::now();
let _ = tree.apply_event(remove_event);
index.apply_event(remove_event);
remove_durations.push(start.elapsed());
// Re-add to restore state (untimed)
let store_event = truncated.to_store_event(i as u64 + 1000000);
let _ = tree.apply_event(store_event);
index.apply_event(store_event);
}
// --- FIND_MATCHES benchmark ---
......@@ -768,7 +775,7 @@ fn bench_sweep(args: &Args) {
};
let start = Instant::now();
let _ = tree.find_matches(query, false);
let _ = index.find_matches(query, false);
find_matches_durations.push(start.elapsed());
}
......@@ -808,19 +815,20 @@ fn bench_dump(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let tree = build_tree(&sequences);
let index = build_index(&sequences, args.nested_map, args.concurrent);
println!(
" Tree built with {} sequences, {} total blocks",
" {} built with {} sequences, {} total blocks",
index.name(),
sequences.len(),
tree.current_size()
index.current_size()
);
// Single iteration timing
let start = Instant::now();
let events = tree.dump_tree_as_events();
let events = index.dump_tree_as_events();
let elapsed = start.elapsed();
println!("\nDUMP_TREE_AS_EVENTS Results:");
......
......@@ -391,7 +391,7 @@ mod tests {
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
......@@ -443,7 +443,7 @@ mod tests {
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
......
......@@ -188,12 +188,12 @@ pub fn generate_sequences(
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
if let Some(gid) = group_id
&& block_idx < prefix_length
{
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
}
}
// Unique suffix (or no shared prefix)
LocalBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
......@@ -205,13 +205,13 @@ pub fn generate_sequences(
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
if let Some(gid) = group_id
&& block_idx < prefix_length
{
return ExternalSequenceBlockHash(
0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64,
);
}
}
ExternalSequenceBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
.collect();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Concurrent Radix Tree implementation for KV cache routing.
//!
//! This module provides a thread-safe radix tree data structure that enables concurrent
//! `find_matches` operations while maintaining correctness for write operations.
//!
//! Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
//! `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
//! `DashMap<..., RwLock<HashMap<...>>>` for the lookup table.
//!
//! # Limitations vs RadixTree
//!
//! - Does NOT support `expiration_duration` / frequency tracking
//! - `new_with_frequency()` is not provided
//! - `find_matches` does not populate `OverlapScores.frequencies`
//!
//! # Concurrency Model
//!
//! - Multiple `find_matches` can run in parallel (read locks only)
//! - Write operations (`apply_event`, `remove_worker`) acquire write locks
//! - The outer `DashMap` distributes contention across shards; inner `RwLock`
//! per worker allows per-worker write concurrency.
//! - Deadlock prevention: always lock parent before child, hand-over-hand locking
use std::{
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
};
use dashmap::DashMap;
use parking_lot::RwLock;
use crate::indexer::SyncIndexer;
use crate::protocols::*;
/// Thread-safe shared reference to a Block.
type SharedBlock = Arc<RwLock<Block>>;
/// A block in the concurrent radix tree.
#[derive(Debug)]
struct Block {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedBlock>,
/// The set of workers that have this block cached.
workers: HashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root).
block_hash: Option<ExternalSequenceBlockHash>,
// NOTE: No recent_uses field.
// Frequency tracking is not supported - keeps find_matches fully read-only.
}
impl Block {
/// Create a new `Block` (used for root node).
fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
block_hash: None,
}
}
/// Create a new `Block` with a specific block hash.
fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
block_hash: Some(block_hash),
}
}
}
/// Thread-safe radix tree for concurrent KV cache lookups.
///
/// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
/// `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
/// `DashMap<..., RwLock<HashMap<...>>>` for the lookup table,
/// enabling concurrent `find_matches` operations.
///
/// # Limitations vs RadixTree
///
/// - Does NOT support `expiration_duration` / frequency tracking
/// - `new_with_frequency()` is not provided
/// - `find_matches` does not populate `OverlapScores.frequencies`
///
/// # Concurrency Model
///
/// - Multiple `find_matches` can run in parallel (read locks only)
/// - Write operations (`apply_event`, `remove_worker`) acquire write locks
/// - The outer `DashMap` distributes contention across shards; inner `RwLock`
/// per worker allows per-worker write concurrency.
/// - Deadlock prevention: always lock parent before child, hand-over-hand locking
pub struct ConcurrentRadixTree {
/// This is the root of the radix/prefix tree.
/// This will only contain root blocks.
root: SharedBlock,
/// Per-worker lookup table for O(1) block access.
/// Outer `DashMap` distributes lock contention across shards; inner `RwLock`
/// per worker protects that worker's block-hash map.
lookup: DashMap<WorkerWithDpRank, RwLock<HashMap<ExternalSequenceBlockHash, SharedBlock>>>,
}
impl Default for ConcurrentRadixTree {
fn default() -> Self {
Self::new()
}
}
// Dropping blocks can cause a cascade of drops that can overflow the stack.
// This custom drop implementation avoids this using an iterative approach.
impl Drop for ConcurrentRadixTree {
fn drop(&mut self) {
let mut stack: Vec<SharedBlock> = Vec::new();
// Break root -> children edge up front
{
let mut root = self.root.write();
stack.extend(root.children.drain().map(|(_, v)| v));
}
// Remove all lookup references (they may include blocks not reachable from root).
// We have &mut self so no concurrent access; drain the DashMap by clearing it
// after collecting all inner values.
let entries: Vec<_> = self
.lookup
.iter()
.flat_map(|entry| entry.value().read().values().cloned().collect::<Vec<_>>())
.collect();
stack.extend(entries);
self.lookup.clear();
// Iteratively free any uniquely-owned blocks without recursion
while let Some(block) = stack.pop() {
if let Ok(rwlock) = Arc::try_unwrap(block) {
let mut inner = rwlock.into_inner();
stack.extend(inner.children.drain().map(|(_, v)| v));
}
}
}
}
impl ConcurrentRadixTree {
/// Create a new `ConcurrentRadixTree`.
pub fn new() -> Self {
Self {
root: Arc::new(RwLock::new(Block::new())),
lookup: DashMap::new(),
}
}
/// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
///
/// This operation is thread-safe and can run concurrently with other `find_matches` calls.
/// Uses hand-over-hand read locking to minimize lock contention.
///
/// ### Arguments
///
/// * `sequence` - A slice of `LocalBlockHash` representing the sequence to match.
/// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
/// Note: `frequencies` field will be empty since frequency tracking is not supported.
pub fn find_matches_impl(
&self,
sequence: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
// Get first child from root.
let first_child = {
let guard = self.root.read();
guard.children.get(&sequence[0]).cloned()
};
let Some(first_child) = first_child else {
return scores;
};
// Initialize active worker set from first child.
let (mut active, mut active_count) = {
let guard = first_child.read();
(guard.workers.clone(), guard.workers.len())
};
if active.is_empty() {
return scores;
}
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
}
for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
}
}
return scores;
}
let mut current = first_child;
let mut matched_depth = 1u32;
// Traverse remaining levels. In a clean tree, workers at a child node
// are always a subset of the parent (along the same path), so:
// - workers can only drop out, never join, as we descend
// - if child.workers.len() == active_count, the sets are identical
//
// However, because apply_removed does NOT cascade to descendants, a
// child may transiently have MORE workers than its parent (stale
// entries from an ancestor remove whose descendant remove events
// haven't arrived yet). We detect this via child_count > active_count
// and fall back to a full membership check.
for (idx, local_hash) in sequence.iter().enumerate().skip(1) {
let next_block = {
let guard = current.read();
guard.children.get(local_hash).cloned()
};
let Some(block) = next_block else {
break;
};
{
let guard = block.read();
let child_count = guard.workers.len();
if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !guard.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
active.clone_from(&guard.workers);
active_count = child_count;
if active_count == 0 {
break;
}
} else if child_count > active_count {
// child_count > active_count means stale entries exist
// (child retains workers already removed from an ancestor).
// Fall back to full membership check: keep only workers
// present in both active and this child, scoring dropouts.
active.retain(|w| {
if guard.workers.contains(w) {
true
} else {
scores.scores.insert(*w, matched_depth);
false
}
});
active_count = active.len();
if active_count == 0 {
break;
}
}
// child_count == active_count: fast path, sets are identical
// (or, in the rare edge case, different membership with same
// cardinality -- accepted as a transient routing quality
// degradation that resolves once pending remove events arrive).
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
break;
}
}
current = block;
matched_depth = (idx + 1) as u32;
}
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
}
// Get tree sizes from lookup.
for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
}
}
scores
}
/// Apply a [`RouterEvent`] to the radix tree.
///
/// This operation is thread-safe. Interior mutability via locks allows
/// `&self` instead of `&mut self`.
///
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (kv_event.event_id, kv_event.data);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
match op {
KvCacheEventData::Stored(op) => self.apply_stored(worker, op, id),
KvCacheEventData::Removed(op) => self.apply_removed(worker, op, id),
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
Ok(())
}
}
}
/// Apply a store operation.
fn apply_stored(
&self,
worker: WorkerWithDpRank,
op: KvCacheStoreData,
id: u64,
) -> Result<(), KvCacheEventError> {
// Ensure this worker has an entry in the outer map.
if !self.lookup.contains_key(&worker) {
self.lookup
.entry(worker)
.or_insert_with(|| RwLock::new(HashMap::new()));
}
let inner_ref = self.lookup.get(&worker).unwrap();
let mut worker_lookup = inner_ref.write();
// Find parent block
let mut current = match op.parent_hash {
Some(parent) => match worker_lookup.get(&parent) {
Some(block) => block.clone(),
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
parent_hash = ?op.parent_hash,
num_blocks = op.blocks.len(),
"Failed to find parent block; skipping store operation"
);
return Err(KvCacheEventError::ParentBlockNotFound);
}
},
None => self.root.clone(),
};
let mut needs_worker_insert = false;
// In each iteration, we lock the parent block and insert the worker into it from
// the previous iteration. This avoids locking a block twice.
for block_data in op.blocks {
let child = {
let mut parent_guard = current.write();
// Insert worker into this node if it was the child from the
// previous iteration (skip for the initial parent, which is
// not one of the blocks being stored).
if needs_worker_insert {
parent_guard.workers.insert(worker);
}
needs_worker_insert = true;
// parent_guard is dropped at the end of this block
match parent_guard.children.get(&block_data.tokens_hash) {
Some(existing) => {
// Verify our simplifying assumption: block_hash is uniform across workers
{
let existing_guard = existing.read();
if existing_guard.block_hash != Some(block_data.block_hash) {
tracing::warn!(
expected = ?block_data.block_hash,
actual = ?existing_guard.block_hash,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
}
existing.clone()
}
None => {
// Reuse from lookup or create new
let new_block = worker_lookup
.get(&block_data.block_hash)
.cloned()
.unwrap_or_else(|| {
Arc::new(RwLock::new(Block::with_hash(block_data.block_hash)))
});
parent_guard
.children
.insert(block_data.tokens_hash, new_block.clone());
new_block
}
}
};
// Update lookup
worker_lookup.insert(block_data.block_hash, child.clone());
current = child;
}
// Insert worker into the last child (not yet handled since there is
// no subsequent iteration to pick it up).
if needs_worker_insert {
current.write().workers.insert(worker);
}
Ok(())
}
/// Apply a remove operation.
///
/// This method does NOT cascade to descendants. Each block hash in the event
/// is removed individually in O(1). Descendant blocks may transiently retain
/// the worker in their `workers` set until their own explicit remove events
/// arrive. `find_matches_impl` handles this by detecting stale entries when
/// `child_count > active_count`.
fn apply_removed(
&self,
worker: WorkerWithDpRank,
op: KvCacheRemoveData,
id: u64,
) -> Result<(), KvCacheEventError> {
let Some(inner_ref) = self.lookup.get(&worker) else {
return Err(KvCacheEventError::BlockNotFound);
};
let mut worker_lookup = inner_ref.write();
for block_hash in op.block_hashes {
let Some(block) = worker_lookup.remove(&block_hash) else {
tracing::debug!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_hash,
"Block not found during remove; skipping"
);
continue;
};
// Remove the worker from this block's worker set.
let mut guard = block.write();
guard.workers.remove(&worker);
if guard.workers.is_empty() {
guard.children.clear();
}
}
Ok(())
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&self, worker_id: WorkerId, keep_worker: bool) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = self
.lookup
.iter()
.filter(|entry| entry.key().worker_id == worker_id)
.map(|entry| *entry.key())
.collect();
for worker in workers {
if let Some((_, inner_lock)) = self.lookup.remove(&worker) {
// We now own the inner RwLock; extract the HashMap.
let blocks = inner_lock.into_inner();
for (_, block) in blocks {
let mut guard = block.write();
guard.workers.remove(&worker);
if guard.workers.is_empty() {
guard.children.clear();
}
}
if keep_worker {
self.lookup.insert(worker, RwLock::new(HashMap::new()));
}
}
}
}
/// Remove a worker and all their blocks from the tree.
pub fn remove_worker(&self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clear all blocks for a worker but keep the worker tracked.
pub fn clear_all_blocks(&self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub fn get_workers(&self) -> Vec<WorkerId> {
let mut worker_ids: Vec<WorkerId> = self
.lookup
.iter()
.map(|entry| entry.key().worker_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
worker_ids.sort_unstable();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping concurrent radix tree as events (contains information about {:?} workers)",
self.lookup.len()
);
let mut events = Vec::new();
let mut event_id = 0u64;
// Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new();
// Process root's children first
{
let root_guard = self.root.read();
for (tokens_hash, child_block) in &root_guard.children {
queue.push_back((child_block.clone(), None, *tokens_hash));
}
}
while let Some((current_block, parent_hash, tokens_hash)) = queue.pop_front() {
let current_guard = current_block.read();
// Get this block's hash (same for all workers)
let block_hash = current_guard
.block_hash
.expect("non-root block must have block_hash");
// For each worker that has this block
for worker in &current_guard.workers {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash,
mm_extra_info: None,
tokens_hash,
}],
}),
dp_rank: worker.dp_rank,
},
};
events.push(event);
event_id += 1;
}
// Enqueue children with this block's hash as their parent
for (child_tokens_hash, child_block) in &current_guard.children {
queue.push_back((child_block.clone(), Some(block_hash), *child_tokens_hash));
}
}
events
}
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.lookup
.iter()
.map(|entry| entry.value().read().len())
.sum()
}
}
// ============================================================================
// SyncIndexer implementation for ConcurrentRadixTree
// ============================================================================
impl SyncIndexer for ConcurrentRadixTree {
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
// Delegate to the existing find_matches method
self.find_matches_impl(sequence, early_exit)
}
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
self.apply_event(event)
}
fn remove_worker(&self, worker_id: WorkerId) {
self.remove_worker(worker_id);
}
fn dump_events(&self) -> Vec<RouterEvent> {
self.dump_tree_as_events()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{create_remove_event, create_store_event};
use std::sync::Arc;
use std::thread;
#[test]
fn test_concurrent_radix_tree_basic() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(trie.lookup.len(), 1);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.read()
.len(),
3
);
trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
assert_eq!(trie.lookup.len(), 2);
}
#[test]
fn test_concurrent_radix_tree_remove() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.unwrap();
trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
.unwrap();
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
.len(),
2
);
trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
.unwrap();
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
.len(),
1
);
}
#[test]
fn test_concurrent_radix_tree_apply_event_errors() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
// Parent block not found
let result = trie.apply_event(create_store_event(
worker_0,
0,
vec![1, 2, 3],
Some(ExternalSequenceBlockHash(12345)),
));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
KvCacheEventError::ParentBlockNotFound
));
}
#[test]
fn test_clear_all_blocks() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
.unwrap();
let result = trie.find_matches_impl(&[LocalBlockHash(0)], false).scores;
assert_eq!(result.len(), 2);
trie.clear_all_blocks(worker_0);
assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.read()
.is_empty()
);
let result = trie
.find_matches_impl(&[LocalBlockHash(0), LocalBlockHash(2)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 2);
}
#[test]
fn test_remove_worker() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
assert_eq!(trie.lookup.len(), 2);
trie.remove_worker(worker_0);
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.len(), 1);
let result = trie
.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 1);
assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
}
#[test]
fn test_concurrent_radix_tree_default() {
let trie: ConcurrentRadixTree = Default::default();
assert!(trie.root.read().children.is_empty());
assert!(trie.root.read().workers.is_empty());
assert!(trie.lookup.is_empty());
}
#[test]
fn test_concurrent_find_matches() {
let trie = Arc::new(ConcurrentRadixTree::new());
// Populate tree
trie.apply_event(create_store_event(0, 0, vec![1, 2, 3, 4, 5], None))
.unwrap();
trie.apply_event(create_store_event(1, 0, vec![1, 2, 6, 7, 8], None))
.unwrap();
let sequence = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
LocalBlockHash(5),
];
// Spawn multiple threads doing concurrent find_matches
let handles: Vec<_> = (0..10)
.map(|_| {
let tree = trie.clone();
let seq = sequence.clone();
thread::spawn(move || tree.find_matches_impl(&seq, false))
})
.collect();
// All should return the same result
let expected_worker_0_score = 5;
let expected_worker_1_score = 2;
for h in handles {
let result = h.join().unwrap();
assert_eq!(
result
.scores
.get(&WorkerWithDpRank::from_worker_id(0))
.unwrap(),
&expected_worker_0_score
);
assert_eq!(
result
.scores
.get(&WorkerWithDpRank::from_worker_id(1))
.unwrap(),
&expected_worker_1_score
);
}
}
#[test]
fn test_concurrent_read_write() {
let trie = Arc::new(ConcurrentRadixTree::new());
// Pre-populate
for i in 0..5 {
trie.apply_event(create_store_event(i, 0, vec![1, 2, 3], None))
.unwrap();
}
let sequence = vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];
// Spawn readers
let reader_handles: Vec<_> = (0..5)
.map(|_| {
let tree = trie.clone();
let seq = sequence.clone();
thread::spawn(move || {
for _ in 0..100 {
let _ = tree.find_matches_impl(&seq, false);
}
})
})
.collect();
// Spawn writers (adding more workers)
let writer_handles: Vec<_> = (5..10)
.map(|i| {
let tree = trie.clone();
thread::spawn(move || {
for j in 0..10 {
let _ =
tree.apply_event(create_store_event(i, j, vec![1, 2, 3, 4 + j], None));
}
})
})
.collect();
// Wait for all threads
for h in reader_handles {
h.join().unwrap();
}
for h in writer_handles {
h.join().unwrap();
}
// Tree should have 10 workers now
assert_eq!(trie.get_workers().len(), 10);
}
#[test]
fn test_remove_parent_does_not_cascade() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
// Create a chain: root -> block1 -> block2 -> block3
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
assert_eq!(trie.lookup.get(&worker_key).unwrap().read().len(), 3);
// Remove ONLY block1 -- descendants should NOT be cascade-removed
trie.apply_event(create_remove_event(worker_1, 2, vec![1]))
.unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap();
let worker_lookup = inner_ref.read();
assert!(
!worker_lookup.contains_key(&ExternalSequenceBlockHash(100)),
"block1 should be removed"
);
assert!(
worker_lookup.contains_key(&ExternalSequenceBlockHash(200)),
"block2 should remain (no cascade)"
);
assert!(
worker_lookup.contains_key(&ExternalSequenceBlockHash(300)),
"block3 should remain (no cascade)"
);
assert_eq!(worker_lookup.len(), 2);
}
#[test]
fn test_remove_all_blocks_individually() {
// Verifies that explicitly removing all blocks (as the engine would)
// cleans up fully, even without cascade.
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
// Remove all three blocks explicitly in one event
trie.apply_event(create_remove_event(worker_1, 2, vec![1, 2, 3]))
.unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap();
let worker_lookup = inner_ref.read();
assert_eq!(worker_lookup.len(), 0, "all blocks should be removed");
}
#[test]
fn test_find_matches_with_stale_entries() {
// Two workers share a full path. Remove worker_1 from the root block
// only (simulating a partial remove). find_matches should still
// produce correct scores for worker_2, and worker_1 should score at
// the stale descendant depth (transiently inflated but not a crash).
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
// Both workers have blocks 1 -> 2 -> 3
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 2, vec![1, 2, 3], None))
.unwrap();
// Remove worker_1 from block 1 only (no cascade to 2,3)
trie.apply_event(create_remove_event(worker_1, 3, vec![1]))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
// worker_2 was never removed, should have full depth
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2)),
Some(&3),
"worker_2 should score 3 (fully present)"
);
// worker_1 was removed from block 1 so it drops out at depth 1.
// But because blocks 2 and 3 still have worker_1 (stale), the
// child_count > active_count path fires and detects the dropout.
// The exact score depends on the detection logic: worker_1 is absent
// from block 1's workers, so it should be scored at depth 0 from the
// first child initialization (it won't appear in `active` at all).
// So worker_1 should NOT appear in scores (it was never in active).
assert!(
!scores
.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)),
"worker_1 should not appear in scores (removed from root-level block)"
);
}
// ========================================================================
// ThreadPoolIndexer<ConcurrentRadixTree> Tests
// ========================================================================
mod thread_pool_indexer_tests {
use tokio::time::Duration;
use super::*;
use crate::indexer::{KvIndexerInterface, ThreadPoolIndexer};
fn make_indexer(
num_workers: usize,
kv_block_size: u32,
) -> ThreadPoolIndexer<ConcurrentRadixTree> {
ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_workers, kv_block_size)
}
#[tokio::test]
async fn test_thread_pool_indexer_basic() {
let indexer = make_indexer(4, 16);
let worker_1 = 0;
let worker_2 = 1;
indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
indexer
.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_remove_worker() {
let indexer = make_indexer(2, 16);
let worker_0 = 0;
let worker_1 = 1;
indexer
.apply_event(create_store_event(worker_0, 1, vec![1, 2, 3], None))
.await;
indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().get_workers().len(), 2);
indexer.remove_worker(worker_0).await;
let workers = indexer.backend().get_workers();
assert_eq!(workers.len(), 1);
assert!(!workers.contains(&worker_0));
assert!(workers.contains(&worker_1));
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_dump_events() {
let indexer = make_indexer(2, 16);
indexer
.apply_event(create_store_event(0, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let events = indexer.dump_events().await.unwrap();
assert_eq!(events.len(), 3);
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_find_matches_for_request() {
let indexer = make_indexer(2, 1);
indexer
.apply_event(create_store_event(0, 1, vec![100, 200, 300], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer.find_matches_for_request(&[100, 200, 300]).await;
assert!(scores.is_ok());
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_sticky_routing() {
let indexer = make_indexer(4, 16);
for i in 0..10 {
indexer
.apply_event(create_store_event(0, i, vec![i as u64], None))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().current_size(), 10);
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_multiple_workers() {
let indexer = make_indexer(4, 16);
for worker_id in 0..8 {
indexer
.apply_event(create_store_event(
worker_id,
1,
vec![1, 2, worker_id as u64 + 10],
None,
))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().get_workers().len(), 8);
let scores = indexer
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 8);
for (_, score) in scores.scores.iter() {
assert_eq!(*score, 2);
}
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_shutdown_idempotent() {
let indexer = make_indexer(2, 16);
indexer
.apply_event(create_store_event(0, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
indexer.shutdown();
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_concurrent_operations() {
use std::sync::Arc;
let indexer = Arc::new(make_indexer(4, 16));
for worker_id in 0..4 {
indexer
.apply_event(create_store_event(worker_id, 1, vec![1, 2, 3, 4, 5], None))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
let sequence = vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];
let mut handles = Vec::new();
for _ in 0..10 {
let idx = indexer.clone();
let seq = sequence.clone();
handles.push(tokio::spawn(
async move { idx.find_matches(seq).await.unwrap() },
));
}
for handle in handles {
let scores = handle.await.unwrap();
assert_eq!(scores.scores.len(), 4);
}
indexer.shutdown();
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Flat HashMap baseline for benchmarking comparison with RadixTree.
//!
//! This module provides a `FlatHashMap` structure that has full feature parity with `RadixTree`
//! but uses flat HashMaps instead of a tree structure. This isolates the overhead of
//! tree traversal (pointer chasing) from pure HashMap operations.
//!
//! The `find_matches` API matches RadixTree exactly: it takes `LocalBlockHash` values
//! and internally computes the cumulative sequence hashes for lookup.
use std::collections::{HashMap, HashSet};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
compute_seq_hash_for_block,
};
/// A flat HashMap-based structure for KV cache indexing.
///
/// Unlike RadixTree which uses a tree of nodes connected by pointers,
/// FlatHashMap uses bidirectional HashMaps. This provides the same
/// find_matches semantics but with better cache locality.
///
/// # Structure
///
/// - `block_to_workers`: Maps ExternalSequenceBlockHash -> Set of workers that have this block.
/// Used for efficient find_matches lookups.
/// - `worker_to_blocks`: Maps Worker -> Set of ExternalSequenceBlockHash they have.
/// Used for remove operations and current_size.
pub struct FlatHashMap {
/// Primary index: block -> workers (for find_matches)
block_to_workers: HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>,
/// Secondary index: worker -> blocks (for remove and current_size)
worker_to_blocks: HashMap<WorkerWithDpRank, HashSet<ExternalSequenceBlockHash>>,
}
impl FlatHashMap {
/// Create a new empty FlatHashMap.
pub fn new() -> Self {
Self {
block_to_workers: HashMap::new(),
worker_to_blocks: HashMap::new(),
}
}
/// Store blocks for a worker.
///
/// Updates both indexes for each block.
pub fn store(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
let worker_blocks = self.worker_to_blocks.entry(worker).or_default();
for &block_hash in block_hashes {
// Add to block -> workers index
self.block_to_workers
.entry(block_hash)
.or_default()
.insert(worker);
// Add to worker -> blocks index
worker_blocks.insert(block_hash);
}
}
/// Remove blocks for a worker.
///
/// Updates both indexes for each block.
pub fn remove(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
let Some(worker_blocks) = self.worker_to_blocks.get_mut(&worker) else {
return;
};
for &block_hash in block_hashes {
// Remove from worker -> blocks index
worker_blocks.remove(&block_hash);
// Remove from block -> workers index
if let Some(workers) = self.block_to_workers.get_mut(&block_hash) {
workers.remove(&worker);
if workers.is_empty() {
self.block_to_workers.remove(&block_hash);
}
}
}
// Clean up empty worker entry
if worker_blocks.is_empty() {
self.worker_to_blocks.remove(&worker);
}
}
/// Find matches for a sequence of local block hashes.
///
/// This has the same signature as `RadixTree::find_matches`: it takes `LocalBlockHash`
/// values and internally computes the cumulative sequence hashes for lookup.
///
/// Returns OverlapScores showing which workers have matching blocks.
/// Stops at first non-match (same semantics as RadixTree).
///
/// # Algorithm
///
/// 1. Compute cumulative sequence hashes from local block hashes
/// 2. For each sequence hash:
/// - Look up which workers have this block
/// - Intersect with previously matching workers (in place)
/// - Track depth for scoring
/// - Stop if no workers remain
///
/// This is O(depth) HashMap lookups + O(num_workers) set operations per level.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
// Compute cumulative sequence hashes from local block hashes
let seq_hashes = compute_seq_hash_for_block(&sequence);
// Track active workers and their match depth
// Workers drop out when they miss a block; their final score is the depth they reached
let mut active_workers: Option<HashSet<WorkerWithDpRank>> = None;
let mut depth = 0u32;
for seq_hash in seq_hashes {
let block_hash = ExternalSequenceBlockHash(seq_hash);
// Look up workers that have this block
let Some(workers) = self.block_to_workers.get(&block_hash) else {
break; // No workers have this block, stop
};
// Intersect with previously active workers (or initialize on first block)
match &mut active_workers {
None => {
// First block: initialize with workers that have it
active_workers = Some(workers.clone());
}
Some(active) => {
// Record score for workers about to drop out (they matched up to current depth)
for &worker in active.iter() {
if !workers.contains(&worker) {
scores.scores.insert(worker, depth);
}
}
// Keep only workers that have this block (in-place, no allocation)
active.retain(|w| workers.contains(w));
}
}
depth += 1;
let active = active_workers.as_ref().unwrap();
if active.is_empty() {
break;
}
// Early exit if only one worker matches
if early_exit && active.len() == 1 {
break;
}
}
// Record final scores for workers that matched all blocks (or until early exit)
if let Some(active) = active_workers {
for worker in active {
scores.scores.insert(worker, depth);
}
}
// Populate tree sizes for workers with scores
for &worker in scores.scores.keys() {
if let Some(blocks) = self.worker_to_blocks.get(&worker) {
scores.tree_sizes.insert(worker, blocks.len());
}
}
scores
}
/// Apply a RouterEvent (for API compatibility with RadixTree).
pub fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
let hashes: Vec<_> = store_data.blocks.iter().map(|b| b.block_hash).collect();
self.store(worker, &hashes);
}
KvCacheEventData::Removed(remove_data) => {
self.remove(worker, &remove_data.block_hashes);
}
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
}
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = self
.worker_to_blocks
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.collect();
for worker in workers {
if let Some(blocks) = self.worker_to_blocks.remove(&worker) {
for block_hash in blocks {
if let Some(workers_set) = self.block_to_workers.get_mut(&block_hash) {
workers_set.remove(&worker);
if workers_set.is_empty() {
self.block_to_workers.remove(&block_hash);
}
}
}
if keep_worker {
// Re-insert worker with empty blocks set to keep it tracked
self.worker_to_blocks.insert(worker, HashSet::new());
}
}
}
}
/// Remove a worker and all their blocks from the index.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clear all blocks for a worker but keep the worker tracked.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// Get all worker IDs currently tracked in the index.
/// Returns unique worker_ids sorted (ignoring dp_rank differences).
pub fn get_workers(&self) -> Vec<WorkerId> {
let mut worker_ids: Vec<WorkerId> = self
.worker_to_blocks
.keys()
.map(|w| w.worker_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
worker_ids.sort_unstable();
worker_ids
}
/// Dump the index as a series of RouterEvents that can reconstruct the state.
/// For API compatibility with RadixTree.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
let mut events = Vec::new();
let mut event_id = 0u64;
for (&worker, blocks) in &self.worker_to_blocks {
for &block_hash in blocks {
let event = RouterEvent {
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, // FlatHashMap doesn't track parent relationships
blocks: vec![KvCacheStoredBlockData {
block_hash,
mm_extra_info: None,
// We don't have the original tokens_hash, use a placeholder
tokens_hash: LocalBlockHash(0),
}],
}),
dp_rank: worker.dp_rank,
},
};
events.push(event);
event_id += 1;
}
}
events
}
/// Returns the total number of (worker, block) pairs stored.
pub fn current_size(&self) -> usize {
self.worker_to_blocks.values().map(|s| s.len()).sum()
}
}
impl Default for FlatHashMap {
fn default() -> Self {
Self::new()
}
}
......@@ -35,6 +35,7 @@
use std::time::Instant;
use async_trait::async_trait;
use dashmap::DashMap;
#[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError;
#[cfg(feature = "metrics")]
......@@ -57,9 +58,9 @@ use serde::{Deserialize, Serialize};
#[cfg(feature = "metrics")]
use std::sync::OnceLock;
use std::{
collections::{HashMap, VecDeque},
collections::VecDeque,
iter,
sync::{Arc, Mutex},
sync::{Arc, Mutex, atomic::AtomicUsize},
thread::JoinHandle,
time::Duration,
};
......@@ -67,105 +68,11 @@ use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::flat_hashmap::FlatHashMap;
// use crate::nested_map::NestedMap;
use crate::protocols::*;
pub use crate::radix_tree::RadixTree;
use dynamo_tokens::SequenceHash;
// ------
// KvIndex - Unified interface for RadixTree and FlatHashMap
// ------
/// Unified interface for KV cache indexing.
///
/// Both `RadixTree` and `FlatHashMap` implement the same core operations:
/// - `find_matches`: Find workers with matching cached blocks
/// - `apply_event`: Apply store/remove events
/// - `remove_worker`: Remove a worker's entries
/// - `get_workers`: Get all tracked workers
/// - `dump_tree_as_events`: Dump state as events
/// - `current_size`: Get total (worker, block) pairs
pub enum KvIndex {
Tree(RadixTree),
Flat(FlatHashMap),
}
impl KvIndex {
/// Create a new KvIndex using RadixTree.
pub fn new_tree() -> Self {
KvIndex::Tree(RadixTree::new())
}
/// Create a new KvIndex using RadixTree with frequency tracking.
pub fn new_tree_with_frequency(expiration_duration: Option<std::time::Duration>) -> Self {
KvIndex::Tree(RadixTree::new_with_frequency(expiration_duration))
}
/// Create a new KvIndex using FlatHashMap.
pub fn new_flat() -> Self {
KvIndex::Flat(FlatHashMap::new())
}
/// Find matches for a sequence of local block hashes.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
match self {
KvIndex::Tree(tree) => tree.find_matches(sequence, early_exit),
KvIndex::Flat(map) => map.find_matches(sequence, early_exit),
}
}
/// Apply a RouterEvent to the index.
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
match self {
KvIndex::Tree(tree) => tree.apply_event(event),
KvIndex::Flat(map) => {
map.apply_event(event);
Ok(())
}
}
}
/// Remove a worker and all their blocks from the index.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
match self {
KvIndex::Tree(tree) => tree.remove_worker(worker_id),
KvIndex::Flat(map) => map.remove_worker(worker_id),
}
}
/// Clear all blocks for a worker but keep the worker tracked.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
match self {
KvIndex::Tree(tree) => tree.clear_all_blocks(worker_id),
KvIndex::Flat(map) => map.clear_all_blocks(worker_id),
}
}
/// Get all worker IDs currently tracked.
pub fn get_workers(&self) -> Vec<WorkerId> {
match self {
KvIndex::Tree(tree) => tree.get_workers(),
KvIndex::Flat(map) => map.get_workers(),
}
}
/// Dump the index as a series of RouterEvents.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
match self {
KvIndex::Tree(tree) => tree.dump_tree_as_events(),
KvIndex::Flat(map) => map.dump_tree_as_events(),
}
}
/// Returns the total number of (worker, block) pairs stored.
pub fn current_size(&self) -> usize {
match self {
KvIndex::Tree(tree) => tree.current_size(),
KvIndex::Flat(map) => map.current_size(),
}
}
}
/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
pub enum KvRouterError {
......@@ -333,14 +240,14 @@ impl KvIndexerMetrics {
/// A request to find matches in the Radix Tree.
pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match.
sequence: Vec<LocalBlockHash>,
pub sequence: Vec<LocalBlockHash>,
/// A boolean indicating whether to exit early if a single match is found.
early_exit: bool,
pub early_exit: bool,
/// A channel sender to send the `OverlapScores` response.
resp: oneshot::Sender<OverlapScores>,
pub resp: oneshot::Sender<OverlapScores>,
/// Timestamp when the request was created (for queue wait time measurement)
#[cfg(feature = "bench")]
created_at: Instant,
pub created_at: Instant,
}
impl MatchRequest {
......@@ -406,17 +313,17 @@ pub trait KvIndexerInterface {
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
async fn apply_event(&mut self, event: RouterEvent);
async fn apply_event(&self, event: RouterEvent);
/// Remove a worker's entries from the trie.
///
/// ### Arguments
///
/// * `worker` - The worker to remove from the trie.
async fn remove_worker(&mut self, worker: WorkerId);
async fn remove_worker(&self, worker: WorkerId);
/// Shutdown the KV Indexer.
fn shutdown(&mut self);
fn shutdown(&self);
/// Dump the entire tree as RouterEvents.
///
......@@ -439,6 +346,245 @@ pub trait KvIndexerInterface {
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError>;
/// Async task that returns when all pending events have been processed.
/// For now, we assume that no requests or events are being sent in the meantime.
/// Returns the amount of events still in the queue at the time of the flush.
/// Used primarily for debugging.
async fn flush(&self) -> usize;
}
// ============================================================================
// SyncIndexer trait and ThreadPoolIndexer generic wrapper
// ============================================================================
/// Trait for thread-safe data structures that support KV cache indexing operations.
///
/// All methods take `&self` and are synchronous. Implementations must be safe for
/// concurrent access (via internal locking, DashMap, etc).
///
/// This trait is used with [`ThreadPoolIndexer`], which wraps a `SyncIndexer` to
/// provide the async [`KvIndexerInterface`] with:
/// - Sticky event routing to N worker threads
/// - Inline reads on the caller's thread (no channel dispatch for find_matches)
pub trait SyncIndexer: Send + Sync + 'static {
/// Find matches for a sequence of block hashes.
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores;
/// Apply a router event to the data structure.
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError>;
/// Remove all entries for a worker.
fn remove_worker(&self, worker_id: WorkerId);
/// Dump the data structure as router events for reconstruction.
fn dump_events(&self) -> Vec<RouterEvent>;
}
/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
///
/// Spawns N OS threads for processing write events (sticky-routed by WorkerId).
/// Read operations (find_matches) are executed inline on the caller's thread,
/// avoiding channel overhead and allowing reads to scale with callers.
///
/// # Architecture
///
/// ```text
/// +------------------------------------+
/// | N Worker Threads (OS threads) |
/// | |
/// worker_event_channels[0] ----------> | Thread 0: blocking recv loop |
/// worker_event_channels[1] ----------> | Thread 1: blocking recv loop |
/// worker_event_channels[N] ----------> | Thread N: blocking recv loop |
/// | |
/// find_matches() ---(inline)---------> | Arc<T: SyncIndexer> |
/// | (shared, thread-safe) |
/// +------------------------------------+
/// ```
pub struct ThreadPoolIndexer<T: SyncIndexer> {
/// Shared backend - thread-safe via internal locking.
backend: Arc<T>,
/// Maps WorkerId to worker thread index for sticky routing.
worker_assignments: DashMap<WorkerId, usize>,
/// Counter for round-robin assignment of new WorkerIds.
worker_assignment_count: AtomicUsize,
/// Channels to send events to worker threads (one per thread).
/// Sending `None` signals the thread to shut down.
worker_event_channels: Vec<flume::Sender<Option<RouterEvent>>>,
/// Number of worker threads.
num_workers: usize,
/// Block size for KV cache.
kv_block_size: u32,
/// Handles to worker threads for joining on shutdown.
thread_handles: Mutex<Vec<JoinHandle<()>>>,
}
impl<T: SyncIndexer> ThreadPoolIndexer<T> {
/// Create a new `ThreadPoolIndexer` wrapping the given backend.
///
/// Spawns `num_workers` OS threads, each running a blocking recv loop
/// that processes events by calling `backend.apply_event()`.
///
/// # Arguments
///
/// * `backend` - The thread-safe data structure to wrap
/// * `num_workers` - Number of worker threads for event processing
/// * `kv_block_size` - Block size for KV cache
///
/// # Panics
///
/// Panics if `num_workers` is 0.
pub fn new(backend: T, num_workers: usize, kv_block_size: u32) -> Self {
assert!(num_workers > 0, "Number of workers must be greater than 0");
let backend = Arc::new(backend);
let mut worker_event_senders = Vec::new();
let mut thread_handles = Vec::new();
for _ in 0..num_workers {
let (event_sender, event_receiver) = flume::unbounded::<Option<RouterEvent>>();
worker_event_senders.push(event_sender);
let backend = Arc::clone(&backend);
let handle = std::thread::spawn(move || {
while let Ok(Some(event)) = event_receiver.recv() {
if let Err(e) = backend.apply_event(event) {
tracing::warn!("Failed to apply event: {:?}", e);
}
}
tracing::debug!("Worker thread shutting down");
});
thread_handles.push(handle);
}
Self {
backend,
worker_assignments: DashMap::new(),
worker_assignment_count: AtomicUsize::new(0),
worker_event_channels: worker_event_senders,
num_workers,
kv_block_size,
thread_handles: Mutex::new(thread_handles),
}
}
/// Get a reference to the underlying backend.
pub fn backend(&self) -> &T {
&self.backend
}
/// Wait for all worker channels to drain.
///
/// Used primarily for testing and benchmarking to ensure all queued events
/// have been picked up by workers before checking results.
pub async fn flush(&self) {
loop {
let all_empty = self.worker_event_channels.iter().all(|ch| ch.is_empty());
if all_empty {
break;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
}
#[async_trait]
impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
// Execute inline on caller's thread - no channel dispatch
Ok(self.backend.find_matches(&sequence, false))
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
Ok(self.backend.find_matches(&sequence, false))
}
async fn apply_event(&self, event: RouterEvent) {
let worker_id = event.worker_id;
// Get or assign worker thread index using sticky round-robin
let thread_idx = *self.worker_assignments.entry(worker_id).or_insert_with(|| {
let idx = self
.worker_assignment_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
idx % self.num_workers
});
// Send event to the assigned worker thread
if let Err(e) = self.worker_event_channels[thread_idx].send(Some(event)) {
tracing::error!(
"Failed to send event to worker thread {}: {:?}",
thread_idx,
e
);
}
}
async fn remove_worker(&self, worker_id: WorkerId) {
// Execute inline - the backend is thread-safe
self.backend.remove_worker(worker_id);
}
fn shutdown(&self) {
// Send shutdown signal (None) to all worker threads
for channel in self.worker_event_channels.iter() {
let _ = channel.send(None);
}
// Take ownership of thread handles and join them
let handles = std::mem::take(
&mut *self
.thread_handles
.lock()
.expect("thread_handles mutex poisoned"),
);
for handle in handles {
if let Err(e) = handle.join() {
tracing::error!("Worker thread panicked during shutdown: {:?}", e);
}
}
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
// Execute inline - the backend is thread-safe
Ok(self.backend.dump_events())
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
// No-op: pruning not supported in ThreadPoolIndexer
Ok(())
}
async fn flush(&self) -> usize {
let curr_size: usize = self.worker_event_channels.iter().map(|ch| ch.len()).sum();
loop {
let all_empty = self.worker_event_channels.iter().all(|ch| ch.is_empty());
if all_empty {
break;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
curr_size
}
}
/// A request to process a routing decision.
......@@ -834,15 +980,15 @@ impl KvIndexerInterface for KvIndexer {
self.find_matches(sequence).await
}
async fn apply_event(&mut self, event: RouterEvent) {
async fn apply_event(&self, event: RouterEvent) {
self.event_tx.send(event).await.unwrap();
}
async fn remove_worker(&mut self, worker: WorkerId) {
async fn remove_worker(&self, worker: WorkerId) {
self.remove_worker_tx.send(worker).await.unwrap();
}
fn shutdown(&mut self) {
fn shutdown(&self) {
self.cancel.cancel();
}
......@@ -871,6 +1017,16 @@ impl KvIndexerInterface for KvIndexer {
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
async fn flush(&self) -> usize {
let curr_size = self.event_tx.max_capacity() - self.event_tx.capacity();
loop {
if self.event_tx.capacity() == self.event_tx.max_capacity() {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
curr_size
}
}
impl KvIndexer {
......@@ -1152,16 +1308,16 @@ impl KvIndexerInterface for LocalKvIndexer {
self.indexer.find_matches_for_request(tokens).await
}
async fn apply_event(&mut self, event: RouterEvent) {
async fn apply_event(&self, event: RouterEvent) {
// Use the buffering version
let _ = self.apply_event_with_buffer(event).await;
}
async fn remove_worker(&mut self, worker: WorkerId) {
async fn remove_worker(&self, worker: WorkerId) {
let _ = self.indexer.remove_worker_sender().send(worker).await;
}
fn shutdown(&mut self) {
fn shutdown(&self) {
// Note: Since indexer is Arc<KvIndexer>, we can't call mutable methods directly.
// The indexer will be shut down when the CancellationToken is cancelled
// or when the last Arc reference is dropped.
......@@ -1182,6 +1338,10 @@ impl KvIndexerInterface for LocalKvIndexer {
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
async fn flush(&self) -> usize {
self.indexer.flush().await
}
}
#[derive(Debug, Clone)]
......@@ -1228,15 +1388,15 @@ pub struct KvIndexerSharded {
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>,
worker_assignments: DashMap<WorkerId, usize>,
worker_counts: Arc<Mutex<Vec<usize>>>,
event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Vec<JoinHandle<()>>,
tasks: Arc<Mutex<Vec<JoinHandle<()>>>>,
}
impl KvIndexerSharded {
......@@ -1261,15 +1421,15 @@ impl KvIndexerSharded {
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
let worker_assignments = DashMap::new();
let worker_counts = Arc::new(Mutex::new(vec![0; num_shards]));
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
let mut tasks = Vec::new();
let tasks = Arc::new(Mutex::new(Vec::new()));
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
......@@ -1299,7 +1459,7 @@ impl KvIndexerSharded {
.build()
.unwrap();
tasks.push(std::thread::spawn(move || {
tasks.lock().unwrap().push(std::thread::spawn(move || {
runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
......@@ -1602,41 +1762,42 @@ impl KvIndexerInterface for KvIndexerSharded {
self.find_matches(sequence).await
}
async fn apply_event(&mut self, event: RouterEvent) {
#[allow(clippy::map_entry)]
if !self.worker_assignments.contains_key(&event.worker_id) {
async fn apply_event(&self, event: RouterEvent) {
let shard = self
.worker_assignments
.entry(event.worker_id)
.or_insert_with(|| {
// Get the shard with the smallest amount of workers.
let selected_shard = self
.worker_counts
let worker_counts = self.worker_counts.lock().unwrap();
let selected_shard = worker_counts
.iter()
.enumerate()
.min_by_key(|&(_, value)| value)
.unwrap()
.0;
drop(worker_counts);
self.worker_assignments
.insert(event.worker_id, selected_shard);
self.worker_counts[selected_shard] += 1;
}
// Increment the count for this shard
self.worker_counts.lock().unwrap()[selected_shard] += 1;
selected_shard
});
self.event_tx[self.worker_assignments[&event.worker_id]]
.send(event)
.await
.unwrap();
self.event_tx[*shard].send(event).await.unwrap();
}
async fn remove_worker(&mut self, worker: WorkerId) {
if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) {
self.worker_counts[shard] -= 1;
async fn remove_worker(&self, worker: WorkerId) {
if let Some((_, shard)) = self.worker_assignments.remove(&worker) {
self.worker_counts.lock().unwrap()[shard] -= 1;
self.remove_worker_tx[shard].send(worker).await.unwrap();
}
}
/// Shutdown the KV Indexer.
fn shutdown(&mut self) {
fn shutdown(&self) {
self.cancel.cancel();
while !self.tasks.is_empty() {
self.tasks.pop().unwrap().join().unwrap();
let mut tasks = self.tasks.lock().unwrap();
while !tasks.is_empty() {
tasks.pop().unwrap().join().unwrap();
}
}
......@@ -1680,6 +1841,25 @@ impl KvIndexerInterface for KvIndexerSharded {
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
async fn flush(&self) -> usize {
let curr_size = self
.event_tx
.iter()
.map(|tx| tx.max_capacity() - tx.capacity())
.sum();
loop {
if self
.event_tx
.iter()
.all(|tx| tx.capacity() == tx.max_capacity())
{
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
curr_size
}
}
impl KvIndexerSharded {
......@@ -1694,8 +1874,8 @@ impl KvIndexerSharded {
let shard_idx = self
.worker_assignments
.get(&worker.worker_id)
.copied()
.unwrap_or(0);
.map(|shard_idx| *shard_idx)
.unwrap_or_default();
self.routing_tx[shard_idx]
.send(RoutingDecisionRequest {
......@@ -1718,169 +1898,1370 @@ impl Drop for KvIndexerSharded {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use crate::concurrent_radix_tree::ConcurrentRadixTree;
use crate::nested_map::PositionalIndexer;
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash, compute_seq_hash_for_block};
use rstest::rstest;
use rstest_reuse::{self, *};
use std::time::Instant;
use tokio::time;
use tokio_util::sync::CancellationToken;
fn setup() {
// Logging init removed to avoid dynamo-runtime dependency
// ============================================================================
// Helper functions
// ============================================================================
/// Create a store event with proper sequence hashes computed from local hashes.
fn make_store_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
make_store_event_with_dp_rank(worker_id, local_hashes, 0)
}
/// Create a store event with a specific dp_rank.
fn make_store_event_with_dp_rank(
worker_id: u64,
local_hashes: &[u64],
dp_rank: u32,
) -> RouterEvent {
make_store_event_full(worker_id, local_hashes, dp_rank, None)
}
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
/// Create a store event with parent hash for continuation sequences.
/// `prefix_hashes` are the hashes of the prefix (to compute parent_hash).
/// `local_hashes` are the new blocks being stored.
fn make_store_event_with_parent(
worker_id: u64,
prefix_hashes: &[u64],
local_hashes: &[u64],
) -> RouterEvent {
// Compute the parent hash from the prefix
let prefix_block_hashes: Vec<LocalBlockHash> =
prefix_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let prefix_seq_hashes = compute_seq_hash_for_block(&prefix_block_hashes);
let parent_hash = prefix_seq_hashes
.last()
.map(|&h| ExternalSequenceBlockHash(h));
// Compute the full sequence including prefix for proper seq_hash calculation
let full_hashes: Vec<u64> = prefix_hashes
.iter()
.chain(local_hashes.iter())
.copied()
.collect();
let full_block_hashes: Vec<LocalBlockHash> =
full_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let full_seq_hashes = compute_seq_hash_for_block(&full_block_hashes);
// Only include the new blocks (skip prefix)
let new_block_hashes: Vec<LocalBlockHash> =
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let new_seq_hashes = &full_seq_hashes[prefix_hashes.len()..];
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: new_block_hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
.zip(new_seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect()
.collect(),
}),
dp_rank: 0,
},
}
}
fn add_blocks(
hashes: Vec<u64>,
/// Create a store event with all options.
fn make_store_event_full(
worker_id: u64,
local_hashes: &[u64],
dp_rank: u32,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
) -> RouterEvent {
let local_block_hashes: Vec<LocalBlockHash> =
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
blocks: local_block_hashes
.iter()
.zip(seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank,
},
}
}
/// Create a remove event for blocks with given local hashes.
fn make_remove_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
make_remove_event_with_dp_rank(worker_id, local_hashes, 0)
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
/// Create a remove event with a specific dp_rank.
fn make_remove_event_with_dp_rank(
worker_id: u64,
local_hashes: &[u64],
dp_rank: u32,
) -> RouterEvent {
let local_block_hashes: Vec<LocalBlockHash> =
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank,
},
}
}
fn make_indexer(
token: &CancellationToken,
num_shards: usize,
kv_block_size: u32,
) -> Box<dyn KvIndexerInterface> {
let metrics = KvIndexerMetrics::new_unregistered();
if num_shards == 1 {
Box::new(KvIndexer::new(token.clone(), kv_block_size, metrics.into()))
} else {
Box::new(KvIndexerSharded::new(
token.clone(),
num_shards,
kv_block_size,
metrics.into(),
))
/// Create a clear event for a worker.
fn make_clear_event(worker_id: u64) -> RouterEvent {
make_clear_event_with_dp_rank(worker_id, 0)
}
/// Create a clear event with a specific dp_rank.
fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
}
// ============================================================================
// KvIndexerInterface tests - parametrized over all implementations
// ============================================================================
#[template]
#[rstest]
fn indexer_template(
#[values(1, 3, 8)] num_shards: usize,
#[values(11, 32, 64)] kv_block_size: usize,
) {
}
fn indexer_template(#[values("single", "sharded", "flat", "concurrent")] variant: &str) {}
#[tokio::test]
#[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
setup();
let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let kv_block_size = 32;
match variant {
"single" => Box::new(KvIndexer::new(token, kv_block_size, metrics)),
"sharded" => Box::new(KvIndexerSharded::new(token, 4, kv_block_size, metrics)),
"flat" => Box::new(ThreadPoolIndexer::new(
PositionalIndexer::new(32),
4,
kv_block_size,
)),
"concurrent" => Box::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
4,
kv_block_size,
)),
_ => panic!("Unknown variant: {}", variant),
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
async fn test_store_and_find(variant: &str) {
let index = make_indexer(variant);
let sequence = vec![compute_block_hash(b"test data")];
let scores = kv_indexer.find_matches(sequence).await;
// Store a sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Find matches using local hashes
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3] for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Find matches for [1, 2, 999] - should match first 2 then stop
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(999),
])
.await
.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove(variant: &str) {
let index = make_indexer(variant);
// Store sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Remove all blocks
index.apply_event(make_remove_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Find should return nothing
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_workers_shared_prefix(variant: &str) {
let index = make_indexer(variant);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index.apply_event(make_store_event(0, &[1, 2])).await;
index.apply_event(make_store_event(1, &[1, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query [1] - both workers should match
let scores = index.find_matches(vec![LocalBlockHash(1)]).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_worker(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Allow time for async event processing
tokio::time::sleep(Duration::from_millis(100)).await;
index.remove_worker(0).await;
// Allow time for async remove_worker processing
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_large_stores(variant: &str) {
let index = make_indexer(variant);
// Test sequences of increasing sizes
for i in 0..10u64 {
let len = 1 << i; // 1, 2, 4, 8, ..., 512
let worker_id = i;
let sequence: Vec<u64> = (1..=len).map(|x| x + (i * 10000)).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify we can find matches for the last stored sequence
let last_seq: Vec<LocalBlockHash> = (1..=512u64)
.map(|x| LocalBlockHash(x + (9 * 10000)))
.collect();
let scores = index.find_matches(last_seq).await.unwrap();
assert!(!scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_dump_and_restore(variant: &str) {
let index = make_indexer(variant);
// Store some data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 4])).await;
// Allow background worker threads to process events.
tokio::time::sleep(Duration::from_millis(100)).await;
// Dump the tree as events
let events = index.dump_events().await.unwrap();
assert!(!events.is_empty());
// Create a new index and replay events
let restored = make_indexer(variant);
for event in events {
restored.apply_event(event).await;
}
// Allow background worker threads to process replayed events.
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify find_matches produces same results
let original_scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
let restored_scores = restored
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(original_scores.scores, restored_scores.scores);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_all_blocks(variant: &str) {
let index = make_indexer(variant);
// Store some data for two workers
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Clear worker 0's blocks using the Cleared event
index.apply_event(make_clear_event(0)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Worker 0's blocks should be gone, worker 1's remain
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_empty_query(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.flush().await;
// Empty query should return empty scores
let scores = index.find_matches(vec![]).await.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_miss_query(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.flush().await;
// Query for non-existent blocks
let scores = index
.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
let index = make_indexer(variant);
index.shutdown();
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
let index = make_indexer(variant);
// Empty index should return no matches
let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens).await.unwrap();
assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Allow time for async processing
index.flush().await;
// Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens).await.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_process_routing_decision(variant: &str) {
let index = make_indexer(variant);
// Create tokens with hashes
let tokens = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, 32);
let worker = WorkerWithDpRank::new(0, 0);
// Process routing decision - should not error
let result = index
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await;
assert!(result.is_ok());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_parent_hash_chains(variant: &str) {
let index = make_indexer(variant);
// Store initial sequence [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Store continuation [4, 5] with parent pointing to block 3
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
index.flush().await;
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let full_seq: Vec<LocalBlockHash> = (1..=5).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Query for just [1, 2, 3] should match 3 blocks
let prefix_seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) {
let index = make_indexer(variant);
// Same worker_id but different dp_ranks should be tracked separately
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 2))
.await;
index.flush().await;
// Query should return all 3 dp_ranks as separate entries
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 1)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 2)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_block_removal(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify all 3 blocks match
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Remove only the last block (block 3)
// To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let full_hashes: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![block_3_seq_hash],
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Query [1, 2] - should still match 2 blocks
let partial_seq: Vec<LocalBlockHash> = (1..=2).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(partial_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_worker(variant: &str) {
let index = make_indexer(variant);
// Store data for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Remove non-existent worker 999 - should not error or affect worker 0
index.remove_worker(999).await;
// Allow time for async processing
tokio::time::sleep(Duration::from_millis(100)).await;
// Worker 0's data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_blocks(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Try to remove blocks [999, 998] that don't exist - should not error
index.apply_event(make_remove_event(0, &[999, 998])).await;
index.flush().await;
// Original data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_then_reuse(variant: &str) {
let index = make_indexer(variant);
// Store initial data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Clear the worker
index.apply_event(make_clear_event(0)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify data is gone
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Store new data for the same worker
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify new data is accessible
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_sequences_per_worker(variant: &str) {
let index = make_indexer(variant);
// Store two disjoint sequences for the same worker
// Sequence 1: [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Sequence 2: [100, 101, 102] (completely different, no parent)
index
.apply_event(make_store_event(0, &[100, 101, 102]))
.await;
index.flush().await;
// Query first sequence
let seq1: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq1).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Query second sequence
let seq2: Vec<LocalBlockHash> = (100..=102).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq2).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Query a mix that doesn't exist as a sequence - should only match first block
let mixed: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
let scores = index.find_matches(mixed).await.unwrap();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_clears_all_dp_ranks(variant: &str) {
let index = make_indexer(variant);
// Store same sequence for different dp_ranks
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify both dp_ranks are present
let seq: Vec<LocalBlockHash> = (1..=3).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(scores.scores.len(), 2);
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index.apply_event(make_clear_event_with_dp_rank(0, 0)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Both dp_ranks should be cleared
let scores = index.find_matches(seq).await.unwrap();
assert!(
scores.scores.is_empty(),
"Cleared event should clear all dp_ranks for a worker"
);
}
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_single_store(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence (128 blocks) in a single event
let seq_len = 128;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query full sequence - should match all blocks
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
// Query prefix (first 64 blocks)
let prefix_query: Vec<LocalBlockHash> = (1..=64).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
64
);
// Query with divergence at position 50
let mut divergent_query: Vec<LocalBlockHash> =
(1..=100).map(|i| LocalBlockHash(i)).collect();
divergent_query[49] = LocalBlockHash(99999); // Position 49 (0-indexed) diverges
let scores = index.find_matches(divergent_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
49
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_continuations(variant: &str) {
let index = make_indexer(variant);
// Build a long sequence through multiple continuations
// First store: blocks 1-50
let first_chunk: Vec<u64> = (1..=50).collect();
index.apply_event(make_store_event(0, &first_chunk)).await;
// Second store: blocks 51-100 (continuation of first)
let second_chunk: Vec<u64> = (51..=100).collect();
index
.apply_event(make_store_event_with_parent(0, &first_chunk, &second_chunk))
.await;
// Third store: blocks 101-150 (continuation of second)
let prefix_1_2: Vec<u64> = (1..=100).collect();
let third_chunk: Vec<u64> = (101..=150).collect();
index
.apply_event(make_store_event_with_parent(0, &prefix_1_2, &third_chunk))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query full sequence - should match all 150 blocks
let full_query: Vec<LocalBlockHash> = (1..=150).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
150
);
// Query crossing continuation boundaries
let cross_boundary_query: Vec<LocalBlockHash> =
(45..=105).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(cross_boundary_query).await.unwrap();
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert!(
scores.scores.is_empty() || scores.scores.get(&WorkerWithDpRank::new(0, 0)).is_none()
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_branching_continuations(variant: &str) {
let index = make_indexer(variant);
// Common prefix: blocks 1-30
let common_prefix: Vec<u64> = (1..=30).collect();
index.apply_event(make_store_event(0, &common_prefix)).await;
// Branch A: blocks 31-60 on worker 0
let branch_a: Vec<u64> = (31..=60).collect();
index
.apply_event(make_store_event_with_parent(0, &common_prefix, &branch_a))
.await;
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index.apply_event(make_store_event(1, &common_prefix)).await;
let branch_b: Vec<u64> = (131..=160).collect();
index
.apply_event(make_store_event_with_parent(1, &common_prefix, &branch_b))
.await;
index.flush().await;
// Query common prefix - both workers should match
let prefix_query: Vec<LocalBlockHash> = (1..=30).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
30
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
// Query branch A path - only worker 0 should match fully
let branch_a_query: Vec<LocalBlockHash> = (1..=60).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(branch_a_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_partial_removal(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Remove blocks 80-100 (the tail)
let tail_hashes: Vec<LocalBlockHash> = (1..=100).map(|i| LocalBlockHash(i)).collect();
let seq_hashes = compute_seq_hash_for_block(&tail_hashes);
let remove_hashes: Vec<ExternalSequenceBlockHash> = seq_hashes[79..100]
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
assert!(scores.unwrap().scores.is_empty());
let remove_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: remove_hashes,
}),
dp_rank: 0,
},
};
index.apply_event(remove_event).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query should now only match first 79 blocks
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
79
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
async fn test_long_sequence_interleaved_workers(variant: &str) {
let index = make_indexer(variant);
// Multiple workers storing overlapping long sequences concurrently
// Worker 0: blocks 1-100
// Worker 1: blocks 1-75
// Worker 2: blocks 1-50
// Worker 3: blocks 1-25
let seq_100: Vec<u64> = (1..=100).collect();
let seq_75: Vec<u64> = (1..=75).collect();
let seq_50: Vec<u64> = (1..=50).collect();
let seq_25: Vec<u64> = (1..=25).collect();
index.apply_event(make_store_event(0, &seq_100)).await;
index.apply_event(make_store_event(1, &seq_75)).await;
index.apply_event(make_store_event(2, &seq_50)).await;
index.apply_event(make_store_event(3, &seq_25)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query for 60 blocks - workers 0,1 match 60, worker 2 matches 50, worker 3 matches 25
let query_60: Vec<LocalBlockHash> = (1..=60).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_60).await.unwrap();
assert_eq!(scores.scores.len(), 4);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
50
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
25
);
}
let tokens = vec![1, 2, 3, 4];
let scores = kv_indexer.find_matches_for_request(&tokens).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) {
let index = make_indexer(variant);
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
// This tests edge cases in the jump search algorithm
// Store sequence of exactly 32 blocks
let seq_32: Vec<u64> = (1..=32).collect();
index.apply_event(make_store_event(0, &seq_32)).await;
// Store sequence of exactly 64 blocks (2x jump_size)
let seq_64: Vec<u64> = (1001..=1064).collect();
index.apply_event(make_store_event(1, &seq_64)).await;
// Store sequence of exactly 96 blocks (3x jump_size)
let seq_96: Vec<u64> = (2001..=2096).collect();
index.apply_event(make_store_event(2, &seq_96)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify all sequences match correctly
let query_32: Vec<LocalBlockHash> = seq_32.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_32).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
32
);
assert!(scores.unwrap().scores.is_empty());
let query_64: Vec<LocalBlockHash> = seq_64.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_64).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
64
);
let query_96: Vec<LocalBlockHash> = seq_96.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_96).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
96
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
setup();
let worker_id = 0;
async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let seq_31: Vec<u64> = (1..=31).collect();
let seq_33: Vec<u64> = (101..=133).collect();
let seq_63: Vec<u64> = (201..=263).collect();
let seq_65: Vec<u64> = (301..=365).collect();
let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
kv_indexer.apply_event(event).await;
index.apply_event(make_store_event(0, &seq_31)).await;
index.apply_event(make_store_event(1, &seq_33)).await;
index.apply_event(make_store_event(2, &seq_63)).await;
index.apply_event(make_store_event(3, &seq_65)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify all sequences match correctly
let query_31: Vec<LocalBlockHash> = seq_31.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_31).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
31
);
let query_33: Vec<LocalBlockHash> = seq_33.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_33).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
33
);
let query_63: Vec<LocalBlockHash> = seq_63.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_63).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
63
);
// No assertion here, just ensuring it runs without panic
let query_65: Vec<LocalBlockHash> = seq_65.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_65).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
65
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
kv_indexer.shutdown();
// Store a long sequence
let sequence: Vec<u64> = (1..=128).collect();
index.apply_event(make_store_event(0, &sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Test divergence exactly at jump boundaries (position 31, 32, 33, 63, 64, 65)
for diverge_pos in [31usize, 32, 33, 63, 64, 65, 95, 96, 97] {
let mut query: Vec<LocalBlockHash> = (1..=128).map(|i| LocalBlockHash(i)).collect();
query[diverge_pos] = LocalBlockHash(99999);
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
diverge_pos as u32,
"Divergence at position {} should match {} blocks",
diverge_pos,
diverge_pos
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: u32) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
async fn test_long_sequence_deep_continuation_chain(variant: &str) {
let index = make_indexer(variant);
// Build a very long sequence through many small continuations
// This tests the parent_hash chain handling
let chunk_size = 10;
let num_chunks = 20; // Total 200 blocks
let mut full_prefix: Vec<u64> = Vec::new();
for chunk_idx in 0..num_chunks {
let chunk_start = chunk_idx * chunk_size + 1;
let chunk: Vec<u64> = (chunk_start..chunk_start + chunk_size)
.map(|x| x as u64)
.collect();
if chunk_idx == 0 {
index.apply_event(make_store_event(0, &chunk)).await;
} else {
index
.apply_event(make_store_event_with_parent(0, &full_prefix, &chunk))
.await;
}
full_prefix.extend(&chunk);
}
tokio::time::sleep(Duration::from_millis(100)).await;
// Query full sequence
let full_query: Vec<LocalBlockHash> = (1..=200).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
200
);
// Query partial prefix crossing multiple chunk boundaries
let partial_query: Vec<LocalBlockHash> = (1..=75).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
75
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_clear_and_rebuild(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify it's stored
let query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Clear the worker
index.apply_event(make_clear_event(0)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify it's cleared
let scores = index.find_matches(query.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Rebuild with a different sequence
let new_sequence: Vec<u64> = (1001..=1100).collect();
index.apply_event(make_store_event(0, &new_sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify new sequence works
let new_query: Vec<LocalBlockHash> =
new_sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(new_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Verify old sequence no longer matches
let scores = index.find_matches(query).await.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
let index = make_indexer(variant);
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// All workers share prefix 1-40
let shared_prefix: Vec<u64> = (1..=40).collect();
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let worker_0_full: Vec<u64> = (1..=100).collect();
// Worker 1: prefix + 141-180 (diverges at block 41)
let worker_1_suffix: Vec<u64> = (141..=180).collect();
// Worker 2: prefix + 241-300 (diverges at block 41)
let worker_2_suffix: Vec<u64> = (241..=300).collect();
// Store for all workers
index.apply_event(make_store_event(0, &worker_0_full)).await;
index.apply_event(make_store_event(1, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
1,
&shared_prefix,
&worker_1_suffix,
))
.await;
index.apply_event(make_store_event(2, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
2,
&shared_prefix,
&worker_2_suffix,
))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let query: Vec<LocalBlockHash> = worker_0_full.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
40
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
40
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_staggered_lengths(variant: &str) {
let index = make_indexer(variant);
// Workers with sequences of staggered lengths to test drain tracking
// Worker 0: 10 blocks
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
for (worker_id, len) in [(0, 10), (1, 20), (2, 35), (3, 64), (4, 100)] {
let sequence: Vec<u64> = (1..=len).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
// Query for 100 blocks - each worker should match their stored length
let query: Vec<LocalBlockHash> = (1..=100).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
10
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
20
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
35
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
64
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(4, 0)).unwrap(),
100
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_very_long_sequence(variant: &str) {
let index = make_indexer(variant);
// Test with a very long sequence (1000 blocks)
let seq_len = 1000u64;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
tokio::time::sleep(Duration::from_millis(100)).await;
// Full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
// Partial match (first 500)
let partial_query: Vec<LocalBlockHash> = (1..=500).map(|i| LocalBlockHash(i)).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
500
);
// Divergence in the middle
let mut mid_diverge: Vec<LocalBlockHash> = (1..=1000).map(|i| LocalBlockHash(i)).collect();
mid_diverge[499] = LocalBlockHash(99999);
let scores = index.find_matches(mid_diverge).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
499
);
}
// ============================================================================
// Tests specific to tree-based implementations (KvIndexer, KvIndexerSharded)
// These use features not available in PositionalIndexer
// ============================================================================
setup();
let mut kv_indexer: Box<dyn KvIndexerInterface>;
#[template]
#[rstest]
fn tree_indexer_template(#[values("single", "sharded")] variant: &str) {}
fn make_tree_indexer_with_frequency(
variant: &str,
expiration: Duration,
) -> Box<dyn KvIndexerInterface> {
let token = CancellationToken::new();
let expiration = Duration::from_millis(50);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let kv_block_size = 32;
if num_shards == 1 {
kv_indexer = Box::new(KvIndexer::new_with_frequency(
match variant {
"single" => Box::new(KvIndexer::new_with_frequency(
token,
Some(expiration),
kv_block_size,
metrics,
None,
));
} else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
)),
"sharded" => Box::new(KvIndexerSharded::new_with_frequency(
token,
num_shards,
4,
Some(expiration),
kv_block_size,
metrics,
None,
));
)),
_ => panic!("Unknown variant: {}", variant),
}
}
#[tokio::test]
#[apply(tree_indexer_template)]
async fn test_frequency(variant: &str) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
let expiration = Duration::from_millis(50);
let kv_indexer = make_tree_indexer_with_frequency(variant, expiration);
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
......@@ -1897,12 +3278,10 @@ mod tests {
);
// Blocks go in cache
let worker_id = 0;
let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
let event = make_store_event(0, &[1, 2, 3, 4]);
kv_indexer.apply_event(event).await;
// First access
// The store event is applied async so poll briefly
// First access - poll briefly since store event is applied async
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
......@@ -1937,200 +3316,29 @@ mod tests {
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
#[tokio::test]
async fn test_dump_tree_as_events_round_trip() {
setup();
// Configuration
let kv_block_size = 32;
let num_shards = 2;
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
// Build a non-trivial indexer with events
let token1 = CancellationToken::new();
let mut original_indexer =
KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
// Apply events to the original indexer
original_indexer
.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_1,
2,
vec![4, 5],
Some(ExternalSequenceBlockHash(100)),
))
.await;
original_indexer
.apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_0,
4,
vec![4],
Some(ExternalSequenceBlockHash(100)),
))
.await;
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// Dump the original indexer
let dump1 = original_indexer.dump_events().await.unwrap();
println!("Dumped {} events", dump1.len());
// Create a new indexer and apply all dumped events
let token2 = CancellationToken::new();
let mut reconstructed_indexer =
KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);
for event in &dump1 {
reconstructed_indexer.apply_event(event.clone()).await;
}
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// Dump the reconstructed indexer
let dump2 = reconstructed_indexer.dump_events().await.unwrap();
// Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
let mut sorted_dump1 = dump1.clone();
let mut sorted_dump2 = dump2.clone();
// Sort by (worker_id, tokens_hash, parent_hash)
let sort_key = |event: &RouterEvent| {
if let KvCacheEventData::Stored(ref data) = event.event.data {
(
event.worker_id,
data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
data.parent_hash.map(|h| h.0).unwrap_or(0),
)
} else {
(event.worker_id, 0, 0)
}
};
sorted_dump1.sort_by_key(sort_key);
sorted_dump2.sort_by_key(sort_key);
// Verify the dumps have the same length
assert_eq!(
sorted_dump1.len(),
sorted_dump2.len(),
"Dumps have different lengths: {} vs {}",
sorted_dump1.len(),
sorted_dump2.len()
);
// Verify each event matches
for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
assert_eq!(
event1.worker_id, event2.worker_id,
"Event {} worker_id mismatch",
i
);
if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
(&event1.event.data, &event2.event.data)
{
assert_eq!(
data1.parent_hash, data2.parent_hash,
"Event {} parent_hash mismatch",
i
);
assert_eq!(
data1.blocks.len(),
data2.blocks.len(),
"Event {} blocks length mismatch",
i
);
for (j, (block1, block2)) in
data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
{
assert_eq!(
block1.tokens_hash, block2.tokens_hash,
"Event {} block {} tokens_hash mismatch",
i, j
);
assert_eq!(
block1.block_hash, block2.block_hash,
"Event {} block {} block_hash mismatch",
i, j
0,
"Blocks were accessed too long ago"
);
}
} else {
panic!("Expected Stored events in both dumps");
}
}
// Also verify that both indexers produce the same match results
for test_seq in [
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
vec![LocalBlockHash(6), LocalBlockHash(7)],
vec![LocalBlockHash(1)],
] {
let scores1 = original_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
let scores2 = reconstructed_indexer
.find_matches(test_seq.clone())
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// Sort the scores to compare
let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
scores1_sorted.sort_by_key(|(k, _)| *k);
scores2_sorted.sort_by_key(|(k, _)| *k);
assert_eq!(
scores1_sorted, scores2_sorted,
"Match scores differ for sequence {:?}",
test_seq
);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
// Clean up
original_indexer.shutdown();
reconstructed_indexer.shutdown();
}
// ============================================================================
// KvIndexerMetrics tests
// ============================================================================
#[test]
fn test_increment_event_applied() {
......@@ -2177,8 +3385,11 @@ mod tests {
);
}
// ============================================================================
// LocalKvIndexer tests
fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
// ============================================================================
fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
......@@ -2202,8 +3413,8 @@ mod tests {
}
#[tokio::test]
async fn returns_slice_within_range() {
let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]);
async fn test_local_indexer_slice_within_range() {
let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]);
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
......@@ -2242,20 +3453,19 @@ mod tests {
}
#[tokio::test]
async fn test_get_events_in_id_range_all_cases() {
async fn test_local_indexer_get_events_in_id_range_all_cases() {
// Create indexer with small buffer (5 events max)
// This way older events will only be in the tree, not the buffer
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4, // block_size
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5, // max_buffer_size - only keeps 5 most recent events
5,
);
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0, // worker_id
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......@@ -2271,9 +3481,7 @@ mod tests {
)
};
// Add 10 events (IDs 5-14)
// Buffer will only keep the last 5: events 10-14
// Tree will have all blocks
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
......@@ -2281,10 +3489,9 @@ mod tests {
.unwrap();
}
// Wait for events to be processed by the tree
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Wait for events to be processed
tokio::time::sleep(Duration::from_millis(100)).await;
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
......@@ -2293,144 +3500,45 @@ mod tests {
}
};
// Helper to extract event IDs from result
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Verify buffer state: should have events 10-14 (last 5)
// Verify buffer state
let buffer_events = indexer.get_all_events_in_buffer();
assert_eq!(
get_ids(buffer_events),
vec![10, 11, 12, 13, 14],
"Buffer should have events 10-14"
);
// ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
// Range is [start, end] inclusive
assert_eq!(get_ids(buffer_events), vec![10, 11, 12, 13, 14]);
// Test: start_id within buffer, no end
// Buffer path tests
let result = indexer.get_events_in_id_range(Some(11), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![11, 12, 13, 14],
"start_id=11 (in buffer) should return [11, 14]"
);
// Test: start_id at buffer boundary
let result = indexer.get_events_in_id_range(Some(10), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"start_id=10 (buffer start) should return [10, 14]"
);
// Test: both start and end within buffer (inclusive)
let result = indexer.get_events_in_id_range(Some(11), Some(13)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![11, 12, 13],
"range [11, 13] inclusive should return 3 events"
);
assert_eq!(get_ids(extract_events(result)), vec![11, 12, 13, 14]);
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"range [10, 14] should return all buffer events"
);
assert_eq!(get_ids(extract_events(result)), vec![10, 11, 12, 13, 14]);
// ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
// Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
// that we get events back (the IDs won't match original IDs)
// Test: (None, None) dumps entire tree
// Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"(None, None) should dump entire tree (10 events)"
);
// Test: (None, Some(_)) dumps entire tree
let result = indexer.get_events_in_id_range(None, Some(8)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
);
assert_eq!(extract_events(result).len(), 10);
// Test: start_id before buffer triggers tree dump
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"start_id=7 (before buffer) should dump entire tree"
);
let result = indexer.get_events_in_id_range(Some(5), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"range [5, 12] extending before buffer should dump entire tree"
);
// ========== EDGE CASES ==========
// Single element when start == end (inclusive range)
let result = indexer.get_events_in_id_range(Some(12), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12],
"start == end should return single event"
);
// InvalidRange when start > end
// Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(
matches!(result, WorkerKvQueryResponse::InvalidRange { .. }),
"start > end should return InvalidRange"
);
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
// TooNew when start_id is beyond buffer
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(
matches!(result, WorkerKvQueryResponse::TooNew { .. }),
"start_id beyond buffer should return TooNew"
);
// Request with end beyond buffer but valid start -> buffer returns what it has
let result = indexer.get_events_in_id_range(Some(12), Some(100)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12, 13, 14],
"range with end beyond buffer should return available buffer events"
);
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
// Tests components of the LocalKvIndexer query without using nats
let worker_id = 42u64;
// Create a local indexer
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
let local_indexer = Arc::new(LocalKvIndexer::new(token, 4, metrics, 100));
// Add events to local indexer's buffer
let test_event_1 = RouterEvent::new(
let test_event = RouterEvent::new(
worker_id,
KvCacheEvent {
event_id: 1,
......@@ -2446,346 +3554,27 @@ mod tests {
},
);
// Apply events with buffer
local_indexer
.apply_event_with_buffer(test_event_1)
.apply_event_with_buffer(test_event)
.await
.unwrap();
// Wait for events to be processed
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
tokio::time::sleep(Duration::from_millis(50)).await;
// Get buffered events (what the query service would return)
let buffered_events = local_indexer.get_all_events_in_buffer();
// Verify buffer contents
assert_eq!(buffered_events.len(), 1, "Buffer should have 1 event");
assert_eq!(buffered_events.len(), 1);
assert_eq!(buffered_events[0].worker_id, worker_id);
assert_eq!(buffered_events[0].event.event_id, 1);
// Build the response that would be sent (Events variant)
let response = WorkerKvQueryResponse::Events(buffered_events.clone());
// Test serialization/deserialization (simulating NATS round-trip)
// Test serialization round-trip
let response = WorkerKvQueryResponse::Events(buffered_events);
let serialized = serde_json::to_vec(&response).unwrap();
let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();
// Verify response correctness
let events = match deserialized {
WorkerKvQueryResponse::Events(e) => e,
_ => panic!("Expected Events variant"),
};
assert_eq!(events.len(), 1);
assert_eq!(events[0].worker_id, worker_id);
assert_eq!(events[0].event.event_id, 1);
// Verify event data
match &events[0].event.data {
KvCacheEventData::Stored(store_data) => {
assert_eq!(store_data.blocks.len(), 1);
assert_eq!(store_data.blocks[0].block_hash.0, 100);
assert_eq!(store_data.blocks[0].tokens_hash.0, 200);
}
_ => panic!("Expected Stored event"),
}
}
}
/// Tests for KvIndex enum (parametrized over RadixTree and FlatHashMap variants).
#[cfg(test)]
mod kv_index_tests {
use super::*;
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash, compute_seq_hash_for_block};
use rstest::rstest;
use rstest_reuse::{self, *};
/// Create a store event with proper sequence hashes computed from local hashes.
fn make_store_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
let local_block_hashes: Vec<LocalBlockHash> =
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: local_block_hashes
.iter()
.zip(seq_hashes.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
}
/// Create a remove event for blocks with given local hashes.
fn make_remove_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
let local_block_hashes: Vec<LocalBlockHash> =
local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: seq_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}),
dp_rank: 0,
},
}
}
#[template]
#[rstest]
fn kv_index_template(#[values("tree", "flat")] variant: &str) {}
fn make_kv_index(variant: &str) -> KvIndex {
match variant {
"tree" => KvIndex::new_tree(),
"flat" => KvIndex::new_flat(),
_ => panic!("Unknown variant: {}", variant),
}
}
#[apply(kv_index_template)]
fn test_store_and_find(variant: &str) {
let mut index = make_kv_index(variant);
// Store a sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
assert_eq!(index.current_size(), 3);
// Find matches using local hashes
let scores = index.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[apply(kv_index_template)]
fn test_partial_match(variant: &str) {
let mut index = make_kv_index(variant);
// Store [1, 2, 3] for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
// Find matches for [1, 2, 999] - should match first 2 then stop
let scores = index.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(999)],
false,
);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
#[apply(kv_index_template)]
fn test_remove(variant: &str) {
let mut index = make_kv_index(variant);
// Store sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
assert_eq!(index.current_size(), 3);
// Remove all blocks
index.apply_event(make_remove_event(0, &[1, 2, 3])).unwrap();
assert_eq!(index.current_size(), 0);
// Find should return nothing
let scores = index.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert!(scores.scores.is_empty());
}
#[apply(kv_index_template)]
fn test_multiple_workers_shared_prefix(variant: &str) {
let mut index = make_kv_index(variant);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index.apply_event(make_store_event(0, &[1, 2])).unwrap();
index.apply_event(make_store_event(1, &[1, 3])).unwrap();
// Query [1] - both workers should match
let scores = index.find_matches(vec![LocalBlockHash(1)], false);
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let scores = index.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
#[apply(kv_index_template)]
fn test_remove_worker(variant: &str) {
let mut index = make_kv_index(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
index.apply_event(make_store_event(1, &[1, 2, 3])).unwrap();
assert_eq!(index.current_size(), 6);
index.remove_worker(0);
assert_eq!(index.current_size(), 3);
let scores = index.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[apply(kv_index_template)]
fn test_get_workers(variant: &str) {
let mut index = make_kv_index(variant);
index.apply_event(make_store_event(0, &[1])).unwrap();
index.apply_event(make_store_event(2, &[1])).unwrap();
index.apply_event(make_store_event(1, &[1])).unwrap();
let workers = index.get_workers();
assert_eq!(workers, vec![0, 1, 2]);
}
#[apply(kv_index_template)]
fn test_early_exit(variant: &str) {
let mut index = make_kv_index(variant);
// Worker 0 has [0, 1, 2], Worker 1 has [0] only
index.apply_event(make_store_event(0, &[0, 1, 2])).unwrap();
index.apply_event(make_store_event(1, &[0])).unwrap();
// Query [0, 1, 2] with early_exit=true
// Should stop after [0, 1] since only worker 0 has block 1
let scores = index.find_matches(
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
true,
);
// Both workers should appear in results
assert_eq!(scores.scores.len(), 2);
// Worker 0 got 2 points (blocks 0 and 1, stopped early)
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Worker 1 got 1 point (block 0 only)
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
// Without early_exit, worker 0 should get all 3 blocks
let scores = index.find_matches(
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
false,
);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[apply(kv_index_template)]
fn test_large_stores(variant: &str) {
let mut index = make_kv_index(variant);
// Test sequences of increasing sizes
for i in 0..10 {
let len = 1 << i; // 1, 2, 4, 8, ..., 512
let worker_id = i;
let sequence: Vec<u64> = (1..=len).map(|x| x + (i as u64 * 10000)).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.unwrap();
assert!(index.current_size() > 0);
}
}
#[apply(kv_index_template)]
fn test_dump_and_restore(variant: &str) {
let mut index = make_kv_index(variant);
// Store some data
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
index.apply_event(make_store_event(1, &[1, 2, 4])).unwrap();
let original_size = index.current_size();
let workers_before = index.get_workers();
// Dump the tree as events
let events = index.dump_tree_as_events();
assert!(!events.is_empty());
// Create a new index and replay events
let mut restored = make_kv_index(variant);
for event in events {
let _ = restored.apply_event(event);
}
// Verify the restored index has same size and workers
assert_eq!(restored.current_size(), original_size);
assert_eq!(restored.get_workers(), workers_before);
// Verify find_matches produces same results
let original_scores = index.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
let restored_scores =
restored.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
assert_eq!(original_scores.scores, restored_scores.scores);
}
#[apply(kv_index_template)]
fn test_clear_all_blocks(variant: &str) {
let mut index = make_kv_index(variant);
// Store some data for two workers
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
index.apply_event(make_store_event(1, &[1, 2, 3])).unwrap();
assert_eq!(index.current_size(), 6);
// Clear worker 0's blocks
index.clear_all_blocks(0);
// Worker 0's blocks should be gone, worker 1's remain
assert_eq!(index.current_size(), 3);
let scores = index.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[apply(kv_index_template)]
fn test_empty_query(variant: &str) {
let mut index = make_kv_index(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
// Empty query should return empty scores
let scores = index.find_matches(vec![], false);
assert!(scores.scores.is_empty());
}
#[apply(kv_index_template)]
fn test_miss_query(variant: &str) {
let mut index = make_kv_index(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
// Query for non-existent blocks
let scores = index.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)], false);
assert!(scores.scores.is_empty());
}
}
......@@ -9,14 +9,19 @@
pub mod approx;
#[cfg(feature = "bench")]
pub mod bench_utils;
pub mod flat_hashmap;
pub mod concurrent_radix_tree;
pub mod indexer;
pub mod nested_map;
pub mod protocols;
pub mod radix_tree;
#[cfg(test)]
pub(crate) mod test_utils;
// Re-export key types for convenience
pub use flat_hashmap::FlatHashMap;
pub use indexer::MaybeError;
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
compute_block_hash_for_seq,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Positional HashMap-based KV cache index with nested structure.
//!
//! This module provides a `PositionalIndexer` that uses nested HashMaps
//! keyed by position for better cache locality and enables jump/binary-search
//! optimizations in find_matches.
//!
//! # Structure
//!
//! - `index`: position -> local_hash -> seq_hash -> workers
//! The main lookup structure. Position-first nesting enables O(1) position access.
//! - `worker_blocks`: worker -> seq_hash -> (position, local_hash)
//! Per-worker reverse lookup for efficient remove operations.
//!
//! # Threading
//!
//! `PositionalIndexer` implements `SyncIndexer`, meaning all its methods are
//! synchronous and thread-safe (via `DashMap` and `RwLock`). To get the full
//! `KvIndexerInterface` with sticky event routing and worker threads, wrap it
//! in a `ThreadPoolIndexer`.
use dashmap::DashMap;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use crate::indexer::SyncIndexer;
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
};
/// Entry for the innermost level of the index.
///
/// Optimizes for the common case where there's only one sequence hash
/// at a given (position, local_hash) pair, avoiding HashMap allocation.
#[derive(Debug, Clone)]
enum SeqEntry {
/// Single seq_hash -> workers mapping (common case, no HashMap allocation)
Single(ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>),
/// Multiple seq_hash -> workers mappings (rare case, different prefixes)
Multi(HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>),
}
impl SeqEntry {
/// Create a new entry with a single worker.
fn new(seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> Self {
let mut workers = HashSet::new();
workers.insert(worker);
Self::Single(seq_hash, workers)
}
/// Insert a worker for a given seq_hash, upgrading to Multi if needed.
fn insert(&mut self, seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) {
match self {
Self::Single(existing_hash, workers) if *existing_hash == seq_hash => {
workers.insert(worker);
}
Self::Single(existing_hash, existing_workers) => {
// Upgrade to Multi
let mut map = HashMap::with_capacity(2);
map.insert(*existing_hash, std::mem::take(existing_workers));
map.entry(seq_hash).or_default().insert(worker);
*self = Self::Multi(map);
}
Self::Multi(map) => {
map.entry(seq_hash).or_default().insert(worker);
}
}
}
/// Remove a worker from a given seq_hash.
/// Returns true if the entry is now completely empty and should be removed.
fn remove(&mut self, seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> bool {
match self {
Self::Single(existing_hash, workers) if *existing_hash == seq_hash => {
workers.remove(&worker);
workers.is_empty()
}
Self::Single(_, _) => false, // Different hash, nothing to remove
Self::Multi(map) => {
if let Some(workers) = map.get_mut(&seq_hash) {
workers.remove(&worker);
if workers.is_empty() {
map.remove(&seq_hash);
}
}
map.is_empty()
}
}
}
/// Get workers for a specific seq_hash.
fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&HashSet<WorkerWithDpRank>> {
match self {
Self::Single(existing_hash, workers) if *existing_hash == seq_hash => Some(workers),
Self::Single(_, _) => None,
Self::Multi(map) => map.get(&seq_hash),
}
}
}
type LevelIndex = RwLock<HashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>;
/// Positional HashMap-based KV cache index.
///
/// Implements [`SyncIndexer`] for use with [`ThreadPoolIndexer`](crate::indexer::ThreadPoolIndexer).
/// All methods are synchronous and thread-safe.
pub struct PositionalIndexer {
index: DashMap<(usize, LocalBlockHash), SeqEntry>,
/// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash)
/// Enables efficient remove operations without global flat reverse map.
worker_blocks: DashMap<WorkerWithDpRank, LevelIndex>,
jump_size: usize,
}
impl PositionalIndexer {
/// Create a new PositionalIndexer.
///
/// # Arguments
/// * `jump_size` - Jump size for find_matches optimization (e.g., 32).
/// The algorithm jumps by this many positions at a time, only scanning
/// intermediate positions when workers drain (stop matching).
pub fn new(jump_size: usize) -> Self {
assert!(jump_size > 0, "jump_size must be greater than 0");
Self {
index: DashMap::new(),
worker_blocks: DashMap::new(),
jump_size,
}
}
}
// ============================================================================
// SyncIndexer implementation
// ============================================================================
impl SyncIndexer for PositionalIndexer {
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
self.jump_search_matches(sequence, early_exit)
}
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
Self::apply_event_impl(&self.index, &self.worker_blocks, event)
}
fn remove_worker(&self, worker_id: WorkerId) {
Self::remove_or_clear_worker_blocks_impl(
&self.index,
&self.worker_blocks,
worker_id,
false,
);
}
fn dump_events(&self) -> Vec<RouterEvent> {
let mut events = Vec::new();
let mut event_id = 0u64;
for entry in self.worker_blocks.iter() {
let worker = *entry.key();
let worker_map = entry.value().read().unwrap();
// Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay.
let mut blocks: Vec<_> = worker_map
.iter()
.map(|(seq_hash, (pos, local_hash))| (*pos, *local_hash, *seq_hash))
.collect();
blocks.sort_unstable_by_key(|(pos, _, _)| *pos);
// Track one valid seq_hash per position for parent_hash synthesis.
let mut last_at_position: HashMap<usize, ExternalSequenceBlockHash> = HashMap::new();
for (pos, local_hash, seq_hash) in blocks {
let parent_hash = if pos == 0 {
None
} else {
match last_at_position.get(&(pos - 1)) {
Some(&parent) => Some(parent),
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
position = pos,
"Orphaned block at position with no parent; skipping in dump"
);
continue;
}
}
};
events.push(RouterEvent {
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: seq_hash,
tokens_hash: local_hash,
mm_extra_info: None,
}],
}),
dp_rank: worker.dp_rank,
},
});
event_id += 1;
last_at_position.insert(pos, seq_hash);
}
}
events
}
}
// ============================================================================
// Event processing (write operations)
// ============================================================================
impl PositionalIndexer {
/// Process an event using the provided index and worker_blocks.
/// This is called from worker threads.
fn apply_event_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
event: RouterEvent,
) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (kv_event.event_id, kv_event.data);
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
tracing::trace!(
id,
"PositionalIndexer::apply_event_impl: operation: {:?}",
op
);
match op {
KvCacheEventData::Stored(store_data) => {
Self::store_blocks_impl(index, worker_blocks, worker, store_data, id)?;
Ok(())
}
KvCacheEventData::Removed(remove_data) => {
Self::remove_blocks_impl(
index,
worker_blocks,
worker,
&remove_data.block_hashes,
id,
)?;
Ok(())
}
KvCacheEventData::Cleared => {
Self::clear_worker_blocks_impl(index, worker_blocks, worker_id);
Ok(())
}
}
}
fn store_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
worker: WorkerWithDpRank,
store_data: KvCacheStoreData,
event_id: u64,
) -> Result<(), KvCacheEventError> {
// Determine starting position based on parent_hash
let start_pos = match store_data.parent_hash {
Some(parent_hash) => {
// Find parent position from worker_blocks
let Some(worker_map) = worker_blocks.get(&worker) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
event_id,
parent_hash = ?parent_hash,
);
return Err(KvCacheEventError::ParentBlockNotFound);
};
let worker_map = worker_map.read().unwrap();
let Some(entry) = worker_map.get(&parent_hash) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
event_id,
parent_hash = ?parent_hash,
);
return Err(KvCacheEventError::ParentBlockNotFound);
};
entry.0 + 1 // parent position + 1
}
None => 0, // Start from position 0
};
if !worker_blocks.contains_key(&worker) {
worker_blocks.insert(worker, RwLock::new(HashMap::new()));
}
let worker_blocks_entry = worker_blocks.get(&worker).unwrap();
let mut worker_map = worker_blocks_entry.write().unwrap();
for (i, block_data) in store_data.blocks.into_iter().enumerate() {
let position = start_pos + i;
let local_hash = block_data.tokens_hash;
let seq_hash = block_data.block_hash;
index
.entry((position, local_hash))
.and_modify(|entry| entry.insert(seq_hash, worker))
.or_insert_with(|| SeqEntry::new(seq_hash, worker));
// Insert into worker_blocks: worker -> seq_hash -> (position, local_hash)
worker_map.insert(seq_hash, (position, local_hash));
}
Ok(())
}
fn remove_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
worker: WorkerWithDpRank,
seq_hashes: &Vec<ExternalSequenceBlockHash>,
event_id: u64,
) -> Result<(), KvCacheEventError> {
let worker_map = worker_blocks.get(&worker).ok_or_else(|| {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
event_id,
block_hashes = ?seq_hashes,
"Failed to find worker blocks to remove"
);
KvCacheEventError::BlockNotFound
})?;
let mut worker_map = worker_map.write().unwrap();
for seq_hash in seq_hashes {
let Some((position, local_hash)) = worker_map.remove(seq_hash) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
event_id,
block_hash = ?seq_hash,
"Failed to find block to remove; skipping remove operation"
);
return Err(KvCacheEventError::BlockNotFound);
};
// Remove from index
if let Some(mut entry) = index.get_mut(&(position, local_hash)) {
let _ = entry.remove(*seq_hash, worker);
}
}
Ok(())
}
/// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked.
/// Static version for use in worker threads.
fn clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
) {
Self::remove_or_clear_worker_blocks_impl(index, worker_blocks, worker_id, true);
}
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.worker_blocks
.iter()
.map(|entry| entry.value().read().unwrap().len())
.sum()
}
/// Remove a worker and all their blocks completely from the index.
#[allow(dead_code)]
fn remove_worker_blocks(&self, worker_id: WorkerId) {
Self::remove_or_clear_worker_blocks_impl(
&self.index,
&self.worker_blocks,
worker_id,
false,
);
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed.
fn remove_or_clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
keep_worker: bool,
) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = worker_blocks
.iter()
.filter(|entry| entry.key().worker_id == worker_id)
.map(|entry| *entry.key())
.collect();
for worker in workers {
if let Some((_, worker_map)) = worker_blocks.remove(&worker) {
// Remove each block from the index
for entry in worker_map.read().unwrap().iter() {
let seq_hash = *entry.0;
let (position, local_hash) = *entry.1;
if let Some(mut entry) = index.get_mut(&(position, local_hash)) {
let _ = entry.remove(seq_hash, worker);
}
}
}
if keep_worker {
// Re-insert worker with empty map to keep it tracked
worker_blocks.insert(worker, RwLock::new(HashMap::new()));
}
}
}
}
// -----------------------------------------------------------------------------
// Jump-based search methods (associated functions for use in worker threads)
// -----------------------------------------------------------------------------
impl PositionalIndexer {
/// Compute sequence hash incrementally from previous hash and current local hash.
#[inline]
fn compute_next_seq_hash(prev_seq_hash: u64, current_local_hash: u64) -> u64 {
let mut bytes = [0u8; 16];
bytes[..8].copy_from_slice(&prev_seq_hash.to_le_bytes());
bytes[8..].copy_from_slice(&current_local_hash.to_le_bytes());
crate::protocols::compute_hash(&bytes)
}
/// Ensure seq_hashes is computed up to and including target_pos.
/// Lazily extends the seq_hashes vector as needed.
#[inline]
fn ensure_seq_hash_computed(
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
target_pos: usize,
sequence: &[LocalBlockHash],
) {
while seq_hashes.len() <= target_pos {
let pos = seq_hashes.len();
if pos == 0 {
// First block's seq_hash equals its local_hash
seq_hashes.push(ExternalSequenceBlockHash::from(sequence[0].0));
} else {
let prev_seq_hash = seq_hashes[pos - 1].0;
let current_local_hash = sequence[pos].0;
let next_hash = Self::compute_next_seq_hash(prev_seq_hash, current_local_hash);
seq_hashes.push(ExternalSequenceBlockHash::from(next_hash));
}
}
}
/// Get workers at a position by verifying both local_hash and seq_hash match.
///
/// Returns None if no workers match at this position.
/// Always computes and verifies the seq_hash to ensure correctness when
/// the query may have diverged from stored sequences at earlier positions.
fn get_workers_lazy(
&self,
position: usize,
local_hash: LocalBlockHash,
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
sequence: &[LocalBlockHash],
) -> Option<HashSet<WorkerWithDpRank>> {
let entry = self.index.get(&(position, local_hash))?;
// Always compute and verify seq_hash to handle divergent queries correctly.
// Even if there's only one seq_hash entry, the query's seq_hash might differ
// if the query diverged from the stored sequence at an earlier position.
Self::ensure_seq_hash_computed(seq_hashes, position, sequence);
let seq_hash = seq_hashes[position];
entry.get(seq_hash).cloned()
}
fn count_workers_at(
&self,
position: usize,
local_hash: LocalBlockHash,
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
sequence: &[LocalBlockHash],
) -> Option<usize> {
let entry = self.index.get(&(position, local_hash))?;
// Always compute and verify seq_hash to handle divergent queries correctly.
// Even if there's only one seq_hash entry, the query's seq_hash might differ
// if the query diverged from the stored sequence at an earlier position.
Self::ensure_seq_hash_computed(seq_hashes, position, sequence);
let seq_hash = seq_hashes[position];
Some(
entry
.get(seq_hash)
.map(|workers| workers.len())
.unwrap_or(0),
)
}
/// Scan positions sequentially, updating active set and recording drain scores.
///
/// Inlines the DashMap lookup so the guard lives for each iteration,
/// avoiding a per-position `HashSet` clone.
#[allow(clippy::too_many_arguments)]
fn linear_scan_drain(
&self,
sequence: &[LocalBlockHash],
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
active: &mut HashSet<WorkerWithDpRank>,
scores: &mut OverlapScores,
lo: usize,
hi: usize,
early_exit: bool,
) {
for pos in lo..hi {
if active.is_empty() {
break;
}
let Some(entry) = self.index.get(&(pos, sequence[pos])) else {
for worker in active.iter() {
scores.scores.insert(*worker, pos as u32);
}
active.clear();
break;
};
Self::ensure_seq_hash_computed(seq_hashes, pos, sequence);
let seq_hash = seq_hashes[pos];
match entry.get(seq_hash) {
Some(workers) => {
active.retain(|w| {
if workers.contains(w) {
true
} else {
scores.scores.insert(*w, pos as u32);
false
}
});
if early_exit && !active.is_empty() {
break;
}
}
None => {
for worker in active.iter() {
scores.scores.insert(*worker, pos as u32);
}
active.clear();
}
}
}
}
/// Jump-based search to find matches for a sequence of block hashes.
///
/// # Algorithm
///
/// 1. Check first position - initialize active set with matching workers
/// 2. Initialize seq_hashes with first block's hash (seq_hash[0] = local_hash[0])
/// 3. Loop: jump by jump_size positions
/// - At each jump, check if active workers still match:
/// - All match: Continue jumping (skip intermediate positions)
/// - None match: Scan range with linear_scan_drain
/// - Partial match: Scan range to find exact drain points
/// 4. Record final scores for remaining active workers
/// 5. Populate tree_sizes from worker_blocks
///
/// # Arguments
/// * `index` - The position -> local_hash -> SeqEntry index
/// * `worker_blocks` - Per-worker reverse lookup for tree sizes
/// * `local_hashes` - Sequence of LocalBlockHash to match
/// * `jump_size` - Number of positions to jump at a time
/// * `early_exit` - If true, stop after finding any match
fn jump_search_matches(
&self,
local_hashes: &[LocalBlockHash],
early_exit: bool,
) -> OverlapScores {
let mut scores = OverlapScores::new();
if local_hashes.is_empty() {
return scores;
}
// Lazily computed sequence hashes
let mut seq_hashes: Vec<ExternalSequenceBlockHash> = Vec::new();
// Check first position to initialize active set
let Some(initial_workers) =
self.get_workers_lazy(0, local_hashes[0], &mut seq_hashes, local_hashes)
else {
return scores;
};
let mut active = initial_workers;
if active.is_empty() {
return scores;
}
if early_exit {
// For early exit, just record that these workers matched at least position 0
for worker in &active {
scores.scores.insert(*worker, 1);
}
// Populate tree_sizes
for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) {
let worker_map = worker_map.read().unwrap();
scores.tree_sizes.insert(*worker, worker_map.len());
}
}
return scores;
}
let len = local_hashes.len();
let mut current_pos = 0;
// Jump through positions
while current_pos < len - 1 && !active.is_empty() {
let next_pos = (current_pos + self.jump_size).min(len - 1);
// Check workers at jump destination
let num_workers_at_next = self
.count_workers_at(
next_pos,
local_hashes[next_pos],
&mut seq_hashes,
local_hashes,
)
.unwrap_or(0);
if num_workers_at_next == active.len() {
current_pos = next_pos;
} else {
// No active workers match at jump destination
// Scan the range to find where each worker drained
self.linear_scan_drain(
local_hashes,
&mut seq_hashes,
&mut active,
&mut scores,
current_pos + 1,
next_pos + 1,
false,
);
current_pos = next_pos;
}
}
// Record final scores for remaining active workers
// They matched all positions through the end
let final_score = len as u32;
for worker in active {
scores.scores.insert(worker, final_score);
}
// Populate tree_sizes from worker_blocks
for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) {
let worker_map = worker_map.read().unwrap();
scores.tree_sizes.insert(*worker, worker_map.len());
}
}
scores
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment