Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
fc92fc18
Unverified
Commit
fc92fc18
authored
Jan 28, 2026
by
Yan Ru Pei
Committed by
GitHub
Jan 29, 2026
Browse files
chore: clean ups in kv-router (#5771)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
842f0f15
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
2301 additions
and
1468 deletions
+2301
-1468
lib/bindings/python/Cargo.lock
lib/bindings/python/Cargo.lock
+43
-0
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+8
-6
lib/kv-router/benches/radix_tree_microbench.rs
lib/kv-router/benches/radix_tree_microbench.rs
+258
-157
lib/kv-router/src/flat_hashmap.rs
lib/kv-router/src/flat_hashmap.rs
+299
-0
lib/kv-router/src/indexer.rs
lib/kv-router/src/indexer.rs
+398
-1292
lib/kv-router/src/lib.rs
lib/kv-router/src/lib.rs
+9
-2
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+158
-0
lib/kv-router/src/radix_tree.rs
lib/kv-router/src/radix_tree.rs
+1118
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+4
-4
lib/llm/src/kv_router/publisher.rs
lib/llm/src/kv_router/publisher.rs
+1
-1
lib/llm/src/kv_router/recorder.rs
lib/llm/src/kv_router/recorder.rs
+1
-1
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+1
-2
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+1
-1
lib/llm/src/kv_router/subscriber.rs
lib/llm/src/kv_router/subscriber.rs
+2
-2
No files found.
lib/bindings/python/Cargo.lock
View file @
fc92fc18
...
...
@@ -700,6 +700,15 @@ dependencies = [
"objc2",
]
[[package]]
name = "bs58"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4"
dependencies = [
"tinyvec",
]
[[package]]
name = "bs62"
version = "0.1.4"
...
...
@@ -1598,6 +1607,24 @@ dependencies = [
"anyhow",
]
[[package]]
name = "dynamo-kv-router"
version = "0.9.0"
dependencies = [
"anyhow",
"async-trait",
"dynamo-runtime",
"dynamo-tokens",
"prometheus",
"rand 0.9.2",
"serde",
"thiserror 2.0.17",
"tokio",
"tokio-util",
"tracing",
"xxhash-rust",
]
[[package]]
name = "dynamo-llm"
version = "0.9.0"
...
...
@@ -1626,9 +1653,11 @@ dependencies = [
"derive_builder",
"dialoguer",
"dynamo-async-openai",
"dynamo-kv-router",
"dynamo-memory",
"dynamo-parsers",
"dynamo-runtime",
"dynamo-tokens",
"either",
"erased-serde",
"etcd-client",
...
...
@@ -1833,6 +1862,20 @@ dependencies = [
"zmq",
]
[[package]]
name = "dynamo-tokens"
version = "0.9.0"
dependencies = [
"bs58",
"bytemuck",
"dashmap 6.1.0",
"derive-getters",
"serde",
"thiserror 2.0.17",
"uuid",
"xxhash-rust",
]
[[package]]
name = "ed25519"
version = "2.2.3"
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
fc92fc18
...
...
@@ -360,7 +360,7 @@ impl KvEventPublisher {
#[pyclass]
#[derive(Clone)]
pub
(
crate
)
struct
OverlapScores
{
inner
:
llm_rs
::
kv_router
::
indexer
::
OverlapScores
,
inner
:
llm_rs
::
kv_router
::
protocols
::
OverlapScores
,
}
#[pymethods]
...
...
@@ -386,7 +386,7 @@ enum RadixTreeRequest {
FindMatches
{
local_block_hashes
:
Vec
<
llm_rs
::
kv_router
::
protocols
::
LocalBlockHash
>
,
early_exit
:
bool
,
response_tx
:
mpsc
::
SyncSender
<
llm_rs
::
kv_router
::
indexer
::
OverlapScores
>
,
response_tx
:
mpsc
::
SyncSender
<
llm_rs
::
kv_router
::
protocols
::
OverlapScores
>
,
},
ApplyEvent
{
worker_id
:
WorkerId
,
...
...
@@ -402,7 +402,7 @@ enum RadixTreeRequest {
response_tx
:
mpsc
::
SyncSender
<
()
>
,
},
DumpTreeAsEvents
{
response_tx
:
mpsc
::
SyncSender
<
Vec
<
llm_rs
::
kv_router
::
indexer
::
RouterEvent
>>
,
response_tx
:
mpsc
::
SyncSender
<
Vec
<
llm_rs
::
kv_router
::
protocols
::
RouterEvent
>>
,
},
Shutdown
,
}
...
...
@@ -616,8 +616,10 @@ impl RadixTree {
>
(
&
kv_cache_event_bytes
)
{
Ok
(
kv_cache_event
)
=>
{
let
router_event
=
llm_rs
::
kv_router
::
indexer
::
RouterEvent
::
new
(
worker_id
,
kv_cache_event
);
let
router_event
=
llm_rs
::
kv_router
::
protocols
::
RouterEvent
::
new
(
worker_id
,
kv_cache_event
,
);
match
radix_tree
.apply_event
(
router_event
)
{
Ok
(
_
)
=>
Ok
(()),
Err
(
e
)
=>
Err
(
PyErr
::
new
::
<
pyo3
::
exceptions
::
PyRuntimeError
,
_
>
(
...
...
@@ -898,7 +900,7 @@ impl KvRecorder {
// Spawn a task to forward events to the recorder
tokio
::
spawn
(
async
move
{
while
let
Some
(
event
)
=
kv_events_rx
.next
()
.await
{
let
event
:
llm_rs
::
kv_router
::
indexer
::
RouterEvent
=
let
event
:
llm_rs
::
kv_router
::
protocols
::
RouterEvent
=
serde_json
::
from_slice
(
&
event
.payload
)
.unwrap
();
tracing
::
debug!
(
"KvRecorder received kv event: {:?}"
,
event
);
if
let
Err
(
e
)
=
event_tx
.send
(
event
)
.await
{
...
...
lib/kv-router/benches/radix_tree_microbench.rs
View file @
fc92fc18
...
...
@@ -15,17 +15,97 @@
use
clap
::{
Parser
,
ValueEnum
};
use
dynamo_kv_router
::{
compute_block_hash_for_seq
,
indexer
::{
RadixTree
,
RouterEvent
}
,
OverlapScores
,
RadixTree
,
RouterEvent
,
compute_block_hash_for_seq
,
flat_hashmap
::
FlatHashMap
,
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEvent
,
KvCacheEventData
,
KvCacheRemoveData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
WorkerId
,
compute_seq_hash_for_block
,
},
};
use
rand
::
rngs
::
StdRng
;
use
rand
::{
Rng
,
SeedableRng
};
use
std
::
time
::{
Duration
,
Instant
};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
///
/// Both structures have feature parity for store, remove, find_matches, and current_size.
/// The key difference is find_matches input:
/// - RadixTree: uses LocalBlockHash (tokens_hash)
/// - FlatHashMap: uses ExternalSequenceBlockHash (cumulative sequence hash)
enum
KvIndex
{
Tree
(
RadixTree
),
Flat
(
FlatHashMap
),
}
impl
KvIndex
{
fn
name
(
&
self
)
->
&
'static
str
{
match
self
{
KvIndex
::
Tree
(
_
)
=>
"RadixTree"
,
KvIndex
::
Flat
(
_
)
=>
"FlatHashMap"
,
}
}
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
{
let
_
=
tree
.apply_event
(
event
);
}
KvIndex
::
Flat
(
map
)
=>
{
map
.apply_event
(
event
);
}
}
}
fn
find_matches_timed
(
&
self
,
seq
:
&
SequenceData
,
early_exit
:
bool
)
->
Duration
{
let
local_hashes
=
seq
.local_hashes
.clone
();
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
local_hashes
,
early_exit
),
KvIndex
::
Flat
(
map
)
=>
map
.find_matches
(
local_hashes
,
early_exit
),
};
start
.elapsed
()
}
fn
find_matches_miss_timed
(
&
self
,
depth
:
usize
,
i
:
usize
,
early_exit
:
bool
)
->
Duration
{
let
miss_hashes
:
Vec
<
LocalBlockHash
>
=
(
0
..
depth
)
.map
(|
j
|
LocalBlockHash
(
0xBAD_C0DE_0000_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
)))
.collect
();
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
miss_hashes
,
early_exit
),
KvIndex
::
Flat
(
map
)
=>
map
.find_matches
(
miss_hashes
,
early_exit
),
};
start
.elapsed
()
}
fn
find_matches_partial_timed
(
&
self
,
seq
:
&
SequenceData
,
half
:
usize
,
i
:
usize
,
early_exit
:
bool
,
)
->
Duration
{
let
mut
partial
=
seq
.local_hashes
[
..
half
]
.to_vec
();
partial
.extend
(
(
0
..
half
)
.map
(|
j
|
LocalBlockHash
(
0xDEAD_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
))),
);
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
partial
,
early_exit
),
KvIndex
::
Flat
(
map
)
=>
map
.find_matches
(
partial
,
early_exit
),
};
start
.elapsed
()
}
fn
current_size
(
&
self
)
->
usize
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.current_size
(),
KvIndex
::
Flat
(
map
)
=>
map
.current_size
(),
}
}
}
/// Sweep benchmark mode
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
ValueEnum)]
enum
SweepMode
{
...
...
@@ -61,6 +141,10 @@ struct Args {
#[arg(long,
default_value
=
"1000"
)]
iterations
:
usize
,
/// Warmup ratio (0.0 to 1.0) - fraction of iterations to discard for warmup
#[arg(long,
default_value
=
"0.1"
)]
warmup_ratio
:
f64
,
/// Prefix prompt ratio (0.0 to 1.0) - portion of sequence from the beginning that is a shared prefix
#[arg(long,
default_value
=
"0.25"
)]
prefix_prompt_ratio
:
f64
,
...
...
@@ -116,6 +200,10 @@ struct Args {
/// Random seed for reproducibility
#[arg(long,
default_value
=
"42"
)]
seed
:
u64
,
/// Use flat HashMap baseline instead of radix tree (for comparison)
#[arg(long)]
flat_hashmap
:
bool
,
}
/// Pre-generated sequence data for benchmarking
...
...
@@ -127,13 +215,14 @@ struct SequenceData {
}
impl
SequenceData
{
fn
new
(
seq_id
:
u64
,
worker_id
:
WorkerId
,
depth
:
usize
)
->
Self
{
let
local_hashes
:
Vec
<
LocalBlockHash
>
=
(
0
..
depth
)
.map
(|
block_idx
|
LocalBlockHash
((
seq_id
<<
32
)
|
(
block_idx
as
u64
)))
.collect
();
let
external_hashes
:
Vec
<
ExternalSequenceBlockHash
>
=
(
0
..
depth
)
.map
(|
block_idx
|
ExternalSequenceBlockHash
((
seq_id
<<
32
)
|
(
block_idx
as
u64
)))
/// Create a new SequenceData from local_hashes.
/// Automatically computes external_hashes using compute_seq_hash_for_block (cumulative hashes).
/// This ensures FlatHashMap can correctly identify block positions.
fn
from_local_hashes
(
worker_id
:
WorkerId
,
local_hashes
:
Vec
<
LocalBlockHash
>
)
->
Self
{
let
seq_hashes
=
compute_seq_hash_for_block
(
&
local_hashes
);
let
external_hashes
=
seq_hashes
.into_iter
()
.map
(
ExternalSequenceBlockHash
)
.collect
();
Self
{
...
...
@@ -190,28 +279,42 @@ fn generate_sequences(
seed
:
u64
,
)
->
Vec
<
SequenceData
>
{
let
mut
sequences
=
Vec
::
with_capacity
(
num_sequences
);
let
prefix_length
:
usize
=
(
depth
as
f64
*
prefix_prompt_ratio
)
.round
()
as
usize
;
let
prefix_length
=
(
depth
as
f64
*
prefix_prompt_ratio
)
.round
()
as
usize
;
let
mut
rng
:
StdRng
=
StdRng
::
seed_from_u64
(
seed
);
for
seq_id
in
0
..
num_sequences
{
let
seq_id_u64
=
seq_id
as
u64
;
let
worker_id
=
(
seq_id
%
num_workers
)
as
WorkerId
;
let
mut
seq
=
SequenceData
::
new
(
seq_id
as
u64
,
worker_id
,
depth
);
if
num_prefix_prompts
>
0
&&
prefix_length
>
0
{
let
group_id
=
rng
.random_range
(
0
..
num_prefix_prompts
);
for
i
in
0
..
prefix_length
{
seq
.local_hashes
[
i
]
=
LocalBlockHash
(
0xDEAD_BEEF_0000_0000
|
((
group_id
as
u64
)
<<
32
)
|
(
i
as
u64
));
}
}
// Determine prefix group for this sequence
let
group_id
=
if
num_prefix_prompts
>
0
&&
prefix_length
>
0
{
Some
(
rng
.random_range
(
0
..
num_prefix_prompts
)
as
u64
)
}
else
{
None
};
sequences
.push
(
seq
);
// Build local_hashes: shared prefix (if applicable) + unique suffix
let
local_hashes
:
Vec
<
LocalBlockHash
>
=
(
0
..
depth
)
.map
(|
block_idx
|
{
let
block_idx_u64
=
block_idx
as
u64
;
if
let
Some
(
gid
)
=
group_id
{
if
block_idx
<
prefix_length
{
// Shared prefix based on group_id
return
LocalBlockHash
(
0xDEAD_BEEF_0000_0000
|
(
gid
<<
32
)
|
block_idx_u64
);
}
}
// Unique suffix (or no shared prefix)
LocalBlockHash
((
seq_id_u64
<<
32
)
|
block_idx_u64
)
})
.collect
();
sequences
.push
(
SequenceData
::
from_local_hashes
(
worker_id
,
local_hashes
));
}
sequences
}
/// Build a pre-populated
tree (prints timing info
)
/// Build a pre-populated
RadixTree (for sweep/dump benchmarks that specifically need RadixTree
)
fn
build_tree
(
sequences
:
&
[
SequenceData
])
->
RadixTree
{
let
num_blocks
:
usize
=
sequences
.iter
()
.map
(|
s
|
s
.local_hashes
.len
())
.sum
();
print!
(
...
...
@@ -239,6 +342,45 @@ fn build_tree(sequences: &[SequenceData]) -> RadixTree {
tree
}
/// Build a pre-populated KvIndex (prints timing info)
fn
build_index
(
sequences
:
&
[
SequenceData
],
use_flat_hashmap
:
bool
)
->
KvIndex
{
let
num_blocks
:
usize
=
sequences
.iter
()
.map
(|
s
|
s
.local_hashes
.len
())
.sum
();
let
name
=
if
use_flat_hashmap
{
"FlatHashMap"
}
else
{
"RadixTree"
};
print!
(
" Building {} with {} sequences ({} blocks)... "
,
name
,
sequences
.len
(),
num_blocks
);
std
::
io
::
Write
::
flush
(
&
mut
std
::
io
::
stdout
())
.unwrap
();
let
start
=
Instant
::
now
();
let
mut
index
=
if
use_flat_hashmap
{
KvIndex
::
Flat
(
FlatHashMap
::
new
())
}
else
{
KvIndex
::
Tree
(
RadixTree
::
new
())
};
for
(
event_id
,
seq
)
in
sequences
.iter
()
.enumerate
()
{
let
event
=
seq
.to_store_event
(
event_id
as
u64
);
index
.apply_event
(
event
);
}
let
elapsed
=
start
.elapsed
();
println!
(
"done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)"
,
elapsed
,
sequences
.len
()
as
f64
/
elapsed
.as_secs_f64
(),
num_blocks
as
f64
/
elapsed
.as_secs_f64
()
);
index
}
/// Statistics for a set of timing measurements
#[derive(Debug)]
struct
LatencyStats
{
...
...
@@ -304,14 +446,18 @@ fn bench_hash(args: &Args) {
})
.collect
();
let
mut
durations
=
Vec
::
with_capacity
(
args
.iterations
);
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
mut
durations
=
Vec
::
with_capacity
(
measured_iters
);
for
(
i
,
tokens
)
in
token_sequences
.iter
()
.enumerate
()
{
let
start
=
Instant
::
now
();
let
_
=
compute_block_hash_for_seq
(
tokens
,
args
.block_size
,
None
);
let
elapsed
=
start
.elapsed
();
durations
.push
(
elapsed
);
if
i
>=
warmup_iters
{
durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
...
...
@@ -322,14 +468,19 @@ fn bench_hash(args: &Args) {
stats
.print
(
"COMPUTE_BLOCK_HASH"
,
args
.depth
);
}
/// Benchmark store_block operation
fn
bench_store
(
args
:
&
Args
)
{
println!
(
"
\n
=== Benchmarking STORE_BLOCK ==="
);
/// Benchmark store or remove operation on a steady-state index.
///
/// Uses a remove/store cycle to maintain size. If `time_store` is true,
/// the store operation is timed; otherwise the remove operation is timed.
fn
bench_store_remove_cycle
(
args
:
&
Args
,
time_store
:
bool
)
{
let
op_name
=
if
time_store
{
"STORE_BLOCK"
}
else
{
"REMOVE_BLOCK"
};
let
num_sequences
=
args
.size
/
args
.depth
;
let
bench_iters
=
args
.iterations
.min
(
num_sequences
);
let
all_sequences
=
generate_sequences
(
let
sequences
=
generate_sequences
(
num_sequences
,
args
.depth
,
args
.num_workers
,
...
...
@@ -337,87 +488,58 @@ fn bench_store(args: &Args) {
args
.num_prefix_prompts
,
args
.seed
,
);
let
split_point
=
num_sequences
.saturating_sub
(
bench_iters
);
let
pre_sequences
=
&
all_sequences
[
..
split_point
];
let
bench_sequences
=
&
all_sequences
[
split_point
..
];
// Build tree once, then store sequences sequentially
// Tree grows from (size - iterations) to size over the benchmark
let
mut
tree
=
build_tree
(
&
pre_sequences
);
println!
(
" Initial tree size: {} blocks, will grow to ~{} blocks"
,
tree
.current_size
(),
tree
.current_size
()
+
bench_iters
*
args
.depth
);
let
mut
index
=
build_index
(
&
sequences
,
args
.flat_hashmap
);
println!
(
"
\n
=== Benchmarking {} ({}) ==="
,
op_name
,
index
.name
());
println!
(
" Size: {} blocks"
,
index
.current_size
());
let
mut
durations
=
Vec
::
with_capacity
(
bench_iters
);
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
mut
durations
=
Vec
::
with_capacity
(
measured_iters
);
for
(
i
,
seq
)
in
bench_sequences
.iter
()
.enumerate
()
{
let
event
=
seq
.to_store_event
(
i
as
u64
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
remove_event
=
seq
.to_remove_event
(
i
as
u64
);
let
store_event
=
seq
.to_store_event
(
i
as
u64
+
args
.iterations
as
u64
);
let
start
=
Instant
::
now
();
let
_
=
tree
.apply_event
(
event
);
let
elapsed
=
start
.elapsed
();
let
elapsed
=
if
time_store
{
index
.apply_event
(
remove_event
);
let
start
=
Instant
::
now
();
index
.apply_event
(
store_event
);
start
.elapsed
()
}
else
{
let
start
=
Instant
::
now
();
index
.apply_event
(
remove_event
);
let
elapsed
=
start
.elapsed
();
index
.apply_event
(
store_event
);
elapsed
};
durations
.push
(
elapsed
);
if
i
>=
warmup_iters
{
durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
bench_iter
s
);
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iteration
s
);
}
}
let
stats
=
LatencyStats
::
from_durations
(
durations
);
stats
.print
(
"STORE_BLOCK"
,
args
.depth
);
stats
.print
(
op_name
,
args
.depth
);
}
/// Benchmark store_block operation
fn
bench_store
(
args
:
&
Args
)
{
bench_store_remove_cycle
(
args
,
true
);
}
/// Benchmark remove_block operation
fn
bench_remove
(
args
:
&
Args
)
{
println!
(
"
\n
=== Benchmarking REMOVE_BLOCK ==="
);
let
num_sequences
=
args
.size
/
args
.depth
;
let
sequences
=
generate_sequences
(
num_sequences
,
args
.depth
,
args
.num_workers
,
args
.prefix_prompt_ratio
,
args
.num_prefix_prompts
,
args
.seed
,
);
// Build tree once, then remove/re-add to restore state after each timed removal
let
mut
tree
=
build_tree
(
&
sequences
);
println!
(
" Tree size: {} blocks"
,
tree
.current_size
());
let
mut
durations
=
Vec
::
with_capacity
(
args
.iterations
);
for
i
in
0
..
args
.iterations
{
// Remove a sequence (timed)
let
seq_to_remove
=
&
sequences
[
i
%
sequences
.len
()];
let
remove_event
=
seq_to_remove
.to_remove_event
(
i
as
u64
);
let
start
=
Instant
::
now
();
let
_
=
tree
.apply_event
(
remove_event
);
let
elapsed
=
start
.elapsed
();
durations
.push
(
elapsed
);
// Re-add the sequence to restore tree state (untimed)
let
store_event
=
seq_to_remove
.to_store_event
(
i
as
u64
+
args
.iterations
as
u64
);
let
_
=
tree
.apply_event
(
store_event
);
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
let
stats
=
LatencyStats
::
from_durations
(
durations
);
stats
.print
(
"REMOVE_BLOCK"
,
args
.depth
);
bench_store_remove_cycle
(
args
,
false
);
}
/// Benchmark find_matches operation
fn
bench_find_matches
(
args
:
&
Args
)
{
println!
(
"
\n
=== Benchmarking FIND_MATCHES ==="
);
let
num_sequences
=
args
.size
/
args
.depth
;
let
sequences
=
generate_sequences
(
num_sequences
,
...
...
@@ -428,104 +550,74 @@ fn bench_find_matches(args: &Args) {
args
.seed
,
);
// Build tree once for all find_matches calls
let
tree
=
build_tree
(
&
sequences
);
let
index
=
build_index
(
&
sequences
,
args
.flat_hashmap
);
println!
(
"
\n
=== Benchmarking FIND_MATCHES ({}) ==="
,
index
.name
());
println!
(
"
Tree b
uilt with {} sequences, {} total blocks"
,
"
B
uilt with {} sequences, {} total blocks"
,
sequences
.len
(),
tree
.current_size
()
index
.current_size
()
);
// Benchmark hit case (lookup existing sequences)
println!
(
"
\n
--- HIT case (existing sequences) ---"
)
;
let
mut
hit_durations
=
Vec
::
with_capacity
(
args
.iterations
)
;
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
half
=
args
.depth
/
2
;
// HIT case
println!
(
"
\n
--- HIT case (existing sequences) ---"
);
let
mut
hit_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
hashes_copy
=
seq
.local_hashes
.clone
();
let
start
=
Instant
::
now
();
let
_
=
tree
.find_matches
(
hashes_copy
,
false
);
let
elapsed
=
start
.elapsed
();
hit_durations
.push
(
elapsed
);
let
elapsed
=
index
.find_matches_timed
(
seq
,
false
);
if
i
>=
warmup_iters
{
hit_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
hit_durations
)
.print
(
"FIND_MATCHES (HIT)"
,
args
.depth
);
let
hit_stats
=
LatencyStats
::
from_durations
(
hit_durations
);
hit_stats
.print
(
"FIND_MATCHES (HIT)"
,
args
.depth
);
// Benchmark miss case (find_matches on non-existing sequences)
// MISS case
println!
(
"
\n
--- MISS case (non-existing sequences) ---"
);
let
mut
miss_durations
=
Vec
::
with_capacity
(
args
.iterations
);
let
mut
miss_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
// Generate a sequence that won't match
let
miss_hashes
:
Vec
<
LocalBlockHash
>
=
(
0
..
args
.depth
)
.map
(|
j
|
LocalBlockHash
(
0xBAD_C0DE_0000_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
)))
.collect
();
let
start
=
Instant
::
now
();
let
_
=
tree
.find_matches
(
miss_hashes
,
false
);
let
elapsed
=
start
.elapsed
();
miss_durations
.push
(
elapsed
);
let
elapsed
=
index
.find_matches_miss_timed
(
args
.depth
,
i
,
false
);
if
i
>=
warmup_iters
{
miss_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
miss_durations
)
.print
(
"FIND_MATCHES (MISS)"
,
args
.depth
);
let
miss_stats
=
LatencyStats
::
from_durations
(
miss_durations
);
miss_stats
.print
(
"FIND_MATCHES (MISS)"
,
args
.depth
);
// Benchmark partial match case
// PARTIAL case
println!
(
"
\n
--- PARTIAL case (prefix match only) ---"
);
let
mut
partial_durations
=
Vec
::
with_capacity
(
args
.iterations
);
let
mut
partial_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
// Use first half of real sequence, second half is garbage
let
half
=
args
.depth
/
2
;
let
mut
partial_hashes
=
seq
.local_hashes
[
..
half
]
.to_vec
();
partial_hashes
.extend
(
(
0
..
half
)
.map
(|
j
|
LocalBlockHash
(
0xDEAD_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
))),
);
let
start
=
Instant
::
now
();
let
_
=
tree
.find_matches
(
partial_hashes
,
false
);
let
elapsed
=
start
.elapsed
();
partial_durations
.push
(
elapsed
);
let
elapsed
=
index
.find_matches_partial_timed
(
seq
,
half
,
i
,
false
);
if
i
>=
warmup_iters
{
partial_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
partial_durations
)
.print
(
"FIND_MATCHES (PARTIAL)"
,
args
.depth
);
let
partial_stats
=
LatencyStats
::
from_durations
(
partial_durations
);
partial_stats
.print
(
"FIND_MATCHES (PARTIAL)"
,
args
.depth
);
// Benchmark with early_exit=true
// EARLY_EXIT case
println!
(
"
\n
--- EARLY_EXIT case ---"
);
let
mut
early_exit_durations
=
Vec
::
with_capacity
(
args
.iterations
);
let
mut
early_exit_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
start
=
Instant
::
now
();
let
_
=
tree
.find_matches
(
seq
.local_hashes
.clone
(),
true
);
let
elapsed
=
start
.elapsed
();
early_exit_durations
.push
(
elapsed
);
let
elapsed
=
index
.find_matches_timed
(
seq
,
true
);
if
i
>=
warmup_iters
{
early_exit_durations
.push
(
elapsed
);
}
}
let
early_exit_stats
=
LatencyStats
::
from_durations
(
early_exit_durations
);
early_exit_stats
.print
(
"FIND_MATCHES (EARLY_EXIT)"
,
args
.depth
);
LatencyStats
::
from_durations
(
early_exit_durations
)
.print
(
"FIND_MATCHES (EARLY_EXIT)"
,
args
.depth
);
}
/// Generate logarithmically spaced values between min and max
...
...
@@ -932,6 +1024,10 @@ fn main() {
eprintln!
(
"prefix_prompt_ratio must be between 0.0 and 1.0"
);
std
::
process
::
exit
(
1
);
}
if
!
(
0.0
..=
1.0
)
.contains
(
&
args
.warmup_ratio
)
{
eprintln!
(
"warmup_ratio must be between 0.0 and 1.0"
);
std
::
process
::
exit
(
1
);
}
let
num_sequences
=
args
.size
/
args
.depth
;
if
matches!
(
...
...
@@ -959,6 +1055,11 @@ fn main() {
println!
(
" Block size: {} tokens"
,
args
.block_size
);
println!
(
" Workers: {}"
,
args
.num_workers
);
println!
(
" Iterations: {}"
,
args
.iterations
);
println!
(
" Warmup: {:.0}% ({} iterations discarded)"
,
args
.warmup_ratio
*
100.0
,
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
);
println!
(
" Prefix prompt ratio: {:.1}% ({} blocks at depth {})"
,
args
.prefix_prompt_ratio
*
100.0
,
...
...
lib/kv-router/src/flat_hashmap.rs
0 → 100644
View file @
fc92fc18
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Flat HashMap baseline for benchmarking comparison with RadixTree.
//!
//! This module provides a `FlatHashMap` structure that has full feature parity with `RadixTree`
//! but uses flat HashMaps instead of a tree structure. This isolates the overhead of
//! tree traversal (pointer chasing) from pure HashMap operations.
//!
//! The `find_matches` API matches RadixTree exactly: it takes `LocalBlockHash` values
//! and internally computes the cumulative sequence hashes for lookup.
use
std
::
collections
::{
HashMap
,
HashSet
};
use
crate
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEvent
,
KvCacheEventData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
WorkerId
,
WorkerWithDpRank
,
compute_seq_hash_for_block
,
};
/// A flat HashMap-based structure for KV cache indexing.
///
/// Unlike RadixTree which uses a tree of nodes connected by pointers,
/// FlatHashMap uses bidirectional HashMaps. This provides the same
/// find_matches semantics but with better cache locality.
///
/// # Structure
///
/// - `block_to_workers`: Maps ExternalSequenceBlockHash -> Set of workers that have this block.
/// Used for efficient find_matches lookups.
/// - `worker_to_blocks`: Maps Worker -> Set of ExternalSequenceBlockHash they have.
/// Used for remove operations and current_size.
pub
struct
FlatHashMap
{
/// Primary index: block -> workers (for find_matches)
block_to_workers
:
HashMap
<
ExternalSequenceBlockHash
,
HashSet
<
WorkerWithDpRank
>>
,
/// Secondary index: worker -> blocks (for remove and current_size)
worker_to_blocks
:
HashMap
<
WorkerWithDpRank
,
HashSet
<
ExternalSequenceBlockHash
>>
,
}
impl
FlatHashMap
{
/// Create a new empty FlatHashMap.
pub
fn
new
()
->
Self
{
Self
{
block_to_workers
:
HashMap
::
new
(),
worker_to_blocks
:
HashMap
::
new
(),
}
}
/// Store blocks for a worker.
///
/// Updates both indexes for each block.
pub
fn
store
(
&
mut
self
,
worker
:
WorkerWithDpRank
,
block_hashes
:
&
[
ExternalSequenceBlockHash
])
{
let
worker_blocks
=
self
.worker_to_blocks
.entry
(
worker
)
.or_default
();
for
&
block_hash
in
block_hashes
{
// Add to block -> workers index
self
.block_to_workers
.entry
(
block_hash
)
.or_default
()
.insert
(
worker
);
// Add to worker -> blocks index
worker_blocks
.insert
(
block_hash
);
}
}
/// Remove blocks for a worker.
///
/// Updates both indexes for each block.
pub
fn
remove
(
&
mut
self
,
worker
:
WorkerWithDpRank
,
block_hashes
:
&
[
ExternalSequenceBlockHash
])
{
let
Some
(
worker_blocks
)
=
self
.worker_to_blocks
.get_mut
(
&
worker
)
else
{
return
;
};
for
&
block_hash
in
block_hashes
{
// Remove from worker -> blocks index
worker_blocks
.remove
(
&
block_hash
);
// Remove from block -> workers index
if
let
Some
(
workers
)
=
self
.block_to_workers
.get_mut
(
&
block_hash
)
{
workers
.remove
(
&
worker
);
if
workers
.is_empty
()
{
self
.block_to_workers
.remove
(
&
block_hash
);
}
}
}
// Clean up empty worker entry
if
worker_blocks
.is_empty
()
{
self
.worker_to_blocks
.remove
(
&
worker
);
}
}
/// Find matches for a sequence of local block hashes.
///
/// This has the same signature as `RadixTree::find_matches`: it takes `LocalBlockHash`
/// values and internally computes the cumulative sequence hashes for lookup.
///
/// Returns OverlapScores showing which workers have matching blocks.
/// Stops at first non-match (same semantics as RadixTree).
///
/// # Algorithm
///
/// 1. Compute cumulative sequence hashes from local block hashes
/// 2. For each sequence hash:
/// - Look up which workers have this block
/// - Intersect with previously matching workers (in place)
/// - Track depth for scoring
/// - Stop if no workers remain
///
/// This is O(depth) HashMap lookups + O(num_workers) set operations per level.
pub
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
)
->
OverlapScores
{
let
mut
scores
=
OverlapScores
::
new
();
if
sequence
.is_empty
()
{
return
scores
;
}
// Compute cumulative sequence hashes from local block hashes
let
seq_hashes
=
compute_seq_hash_for_block
(
&
sequence
);
// Track active workers and their match depth
// Workers drop out when they miss a block; their final score is the depth they reached
let
mut
active_workers
:
Option
<
HashSet
<
WorkerWithDpRank
>>
=
None
;
let
mut
depth
=
0u32
;
for
seq_hash
in
seq_hashes
{
let
block_hash
=
ExternalSequenceBlockHash
(
seq_hash
);
// Look up workers that have this block
let
Some
(
workers
)
=
self
.block_to_workers
.get
(
&
block_hash
)
else
{
break
;
// No workers have this block, stop
};
// Intersect with previously active workers (or initialize on first block)
match
&
mut
active_workers
{
None
=>
{
// First block: initialize with workers that have it
active_workers
=
Some
(
workers
.clone
());
}
Some
(
active
)
=>
{
// Record score for workers about to drop out (they matched up to current depth)
for
&
worker
in
active
.iter
()
{
if
!
workers
.contains
(
&
worker
)
{
scores
.scores
.insert
(
worker
,
depth
);
}
}
// Keep only workers that have this block (in-place, no allocation)
active
.retain
(|
w
|
workers
.contains
(
w
));
}
}
depth
+=
1
;
let
active
=
active_workers
.as_ref
()
.unwrap
();
if
active
.is_empty
()
{
break
;
}
// Early exit if only one worker matches
if
early_exit
&&
active
.len
()
==
1
{
break
;
}
}
// Record final scores for workers that matched all blocks (or until early exit)
if
let
Some
(
active
)
=
active_workers
{
for
worker
in
active
{
scores
.scores
.insert
(
worker
,
depth
);
}
}
// Populate tree sizes for workers with scores
for
&
worker
in
scores
.scores
.keys
()
{
if
let
Some
(
blocks
)
=
self
.worker_to_blocks
.get
(
&
worker
)
{
scores
.tree_sizes
.insert
(
worker
,
blocks
.len
());
}
}
scores
}
/// Apply a RouterEvent (for API compatibility with RadixTree).
pub
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
{
let
worker
=
WorkerWithDpRank
::
new
(
event
.worker_id
,
event
.event.dp_rank
);
match
event
.event.data
{
KvCacheEventData
::
Stored
(
store_data
)
=>
{
let
hashes
:
Vec
<
_
>
=
store_data
.blocks
.iter
()
.map
(|
b
|
b
.block_hash
)
.collect
();
self
.store
(
worker
,
&
hashes
);
}
KvCacheEventData
::
Removed
(
remove_data
)
=>
{
self
.remove
(
worker
,
&
remove_data
.block_hashes
);
}
KvCacheEventData
::
Cleared
=>
{
self
.clear_all_blocks
(
worker
.worker_id
);
}
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn
remove_or_clear_worker_blocks
(
&
mut
self
,
worker_id
:
WorkerId
,
keep_worker
:
bool
)
{
// Collect all WorkerWithDpRank keys that match this worker_id
let
workers
:
Vec
<
WorkerWithDpRank
>
=
self
.worker_to_blocks
.keys
()
.filter
(|
w
|
w
.worker_id
==
worker_id
)
.copied
()
.collect
();
for
worker
in
workers
{
if
let
Some
(
blocks
)
=
self
.worker_to_blocks
.remove
(
&
worker
)
{
for
block_hash
in
blocks
{
if
let
Some
(
workers_set
)
=
self
.block_to_workers
.get_mut
(
&
block_hash
)
{
workers_set
.remove
(
&
worker
);
if
workers_set
.is_empty
()
{
self
.block_to_workers
.remove
(
&
block_hash
);
}
}
}
if
keep_worker
{
// Re-insert worker with empty blocks set to keep it tracked
self
.worker_to_blocks
.insert
(
worker
,
HashSet
::
new
());
}
}
}
}
/// Remove a worker and all their blocks from the index.
pub
fn
remove_worker
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
false
);
}
/// Clear all blocks for a worker but keep the worker tracked.
pub
fn
clear_all_blocks
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
true
);
}
/// Get all worker IDs currently tracked in the index.
/// Returns unique worker_ids sorted (ignoring dp_rank differences).
pub
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
let
mut
worker_ids
:
Vec
<
WorkerId
>
=
self
.worker_to_blocks
.keys
()
.map
(|
w
|
w
.worker_id
)
.collect
::
<
HashSet
<
_
>>
()
.into_iter
()
.collect
();
worker_ids
.sort_unstable
();
worker_ids
}
/// Dump the index as a series of RouterEvents that can reconstruct the state.
/// For API compatibility with RadixTree.
pub
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
for
(
&
worker
,
blocks
)
in
&
self
.worker_to_blocks
{
for
&
block_hash
in
blocks
{
let
event
=
RouterEvent
{
worker_id
:
worker
.worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
// FlatHashMap doesn't track parent relationships
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
,
mm_extra_info
:
None
,
// We don't have the original tokens_hash, use a placeholder
tokens_hash
:
LocalBlockHash
(
0
),
}],
}),
dp_rank
:
worker
.dp_rank
,
},
};
events
.push
(
event
);
event_id
+=
1
;
}
}
events
}
/// Returns the total number of (worker, block) pairs stored.
pub
fn
current_size
(
&
self
)
->
usize
{
self
.worker_to_blocks
.values
()
.map
(|
s
|
s
.len
())
.sum
()
}
}
impl
Default
for
FlatHashMap
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
lib/kv-router/src/indexer.rs
View file @
fc92fc18
...
...
@@ -54,21 +54,115 @@ use serde::{Deserialize, Serialize};
#[cfg(feature
=
"metrics"
)]
use
std
::
sync
::
OnceLock
;
use
std
::{
cell
::
RefCell
,
collections
::{
HashMap
,
HashSet
,
VecDeque
},
collections
::{
HashMap
,
VecDeque
},
iter
,
rc
::
Rc
,
sync
::{
Arc
,
Mutex
},
thread
::
JoinHandle
,
time
::
{
Duration
,
Instant
},
time
::
Duration
,
};
use
tokio
::
sync
::{
broadcast
,
mpsc
,
oneshot
};
use
tokio_util
::
sync
::
CancellationToken
;
use
crate
::
approx
::{
BlockEntry
,
PruneConfig
,
PruneManager
};
use
crate
::
flat_hashmap
::
FlatHashMap
;
use
crate
::
protocols
::
*
;
pub
use
crate
::
radix_tree
::
RadixTree
;
use
dynamo_tokens
::
SequenceHash
;
// ------
// KvIndex - Unified interface for RadixTree and FlatHashMap
// ------
/// Unified interface for KV cache indexing.
///
/// Both `RadixTree` and `FlatHashMap` implement the same core operations:
/// - `find_matches`: Find workers with matching cached blocks
/// - `apply_event`: Apply store/remove events
/// - `remove_worker`: Remove a worker's entries
/// - `get_workers`: Get all tracked workers
/// - `dump_tree_as_events`: Dump state as events
/// - `current_size`: Get total (worker, block) pairs
pub
enum
KvIndex
{
Tree
(
RadixTree
),
Flat
(
FlatHashMap
),
}
impl
KvIndex
{
/// Create a new KvIndex using RadixTree.
pub
fn
new_tree
()
->
Self
{
KvIndex
::
Tree
(
RadixTree
::
new
())
}
/// Create a new KvIndex using RadixTree with frequency tracking.
pub
fn
new_tree_with_frequency
(
expiration_duration
:
Option
<
std
::
time
::
Duration
>
)
->
Self
{
KvIndex
::
Tree
(
RadixTree
::
new_with_frequency
(
expiration_duration
))
}
/// Create a new KvIndex using FlatHashMap.
pub
fn
new_flat
()
->
Self
{
KvIndex
::
Flat
(
FlatHashMap
::
new
())
}
/// Find matches for a sequence of local block hashes.
pub
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
)
->
OverlapScores
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
sequence
,
early_exit
),
KvIndex
::
Flat
(
map
)
=>
map
.find_matches
(
sequence
,
early_exit
),
}
}
/// Apply a RouterEvent to the index.
pub
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.apply_event
(
event
),
KvIndex
::
Flat
(
map
)
=>
{
map
.apply_event
(
event
);
Ok
(())
}
}
}
/// Remove a worker and all their blocks from the index.
pub
fn
remove_worker
(
&
mut
self
,
worker_id
:
WorkerId
)
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.remove_worker
(
worker_id
),
KvIndex
::
Flat
(
map
)
=>
map
.remove_worker
(
worker_id
),
}
}
/// Clear all blocks for a worker but keep the worker tracked.
pub
fn
clear_all_blocks
(
&
mut
self
,
worker_id
:
WorkerId
)
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.clear_all_blocks
(
worker_id
),
KvIndex
::
Flat
(
map
)
=>
map
.clear_all_blocks
(
worker_id
),
}
}
/// Get all worker IDs currently tracked.
pub
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.get_workers
(),
KvIndex
::
Flat
(
map
)
=>
map
.get_workers
(),
}
}
/// Dump the index as a series of RouterEvents.
pub
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.dump_tree_as_events
(),
KvIndex
::
Flat
(
map
)
=>
map
.dump_tree_as_events
(),
}
}
/// Returns the total number of (worker, block) pairs stored.
pub
fn
current_size
(
&
self
)
->
usize
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.current_size
(),
KvIndex
::
Flat
(
map
)
=>
map
.current_size
(),
}
}
}
/// Errors that can occur in the KV Router.
#[derive(Debug,
thiserror::Error)]
pub
enum
KvRouterError
{
...
...
@@ -85,47 +179,6 @@ pub enum KvRouterError {
PruneFailed
(
String
),
}
/// Errors that can occur during KV Cache Event processing.
#[derive(Debug,
thiserror::Error)]
pub
enum
KvCacheEventError
{
#[error(
"Failed to find parent block"
)]
ParentBlockNotFound
,
#[error(
"Failed to find block"
)]
BlockNotFound
,
#[error(
"Invalid block sequence"
)]
InvalidBlockSequence
,
}
/// A shared reference to a [`RadixBlock`].
type
SharedRadixBlock
=
Rc
<
RefCell
<
RadixBlock
>>
;
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug,
Clone,
Serialize,
Deserialize,
PartialEq)]
pub
struct
RouterEvent
{
/// The ID of the worker emitting the event.
pub
worker_id
:
WorkerId
,
/// The cache event associated with the worker.
pub
event
:
KvCacheEvent
,
}
impl
RouterEvent
{
/// Create a new `RouterEvent`.
///
/// ### Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
///
/// ### Returns
///
/// A new `RouterEvent`.
pub
fn
new
(
worker_id
:
WorkerId
,
event
:
KvCacheEvent
)
->
Self
{
Self
{
worker_id
,
event
}
}
}
// -------
// Distributed router - Worker KV Query types
// -------
...
...
@@ -174,450 +227,6 @@ impl MaybeError for WorkerKvQueryResponse {
}
}
/// A block in the Radix Tree.
#[derive(Debug)]
struct
RadixBlock
{
/// A map of child blocks, keyed by their local block hash.
children
:
HashMap
<
LocalBlockHash
,
SharedRadixBlock
>
,
/// The set of workers that have this block cached.
workers
:
HashSet
<
WorkerWithDpRank
>
,
/// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
block_hash
:
Option
<
ExternalSequenceBlockHash
>
,
/// A buffer of times that this block was last traversed
recent_uses
:
VecDeque
<
Instant
>
,
}
impl
RadixBlock
{
/// Create a new `RadixBlock` (used for root node).
///
/// ### Returns
///
/// A new `RadixBlock` with no block_hash.
pub
fn
new
()
->
Self
{
Self
{
children
:
HashMap
::
new
(),
workers
:
HashSet
::
new
(),
block_hash
:
None
,
recent_uses
:
VecDeque
::
new
(),
}
}
/// Create a new `RadixBlock` with a specific block hash.
///
/// ### Returns
///
/// A new `RadixBlock` with the given block_hash.
pub
fn
with_hash
(
block_hash
:
ExternalSequenceBlockHash
)
->
Self
{
Self
{
children
:
HashMap
::
new
(),
workers
:
HashSet
::
new
(),
block_hash
:
Some
(
block_hash
),
recent_uses
:
VecDeque
::
new
(),
}
}
}
pub
struct
RadixTree
{
/// This is the root of the radix/prefix tree
/// This will only contain root blocks
root
:
SharedRadixBlock
,
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
lookup
:
HashMap
<
WorkerWithDpRank
,
HashMap
<
ExternalSequenceBlockHash
,
SharedRadixBlock
>>
,
/// The time buffer the radix tree should check when considering frequence of block accesses
expiration_duration
:
Option
<
Duration
>
,
}
impl
Default
for
RadixTree
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
// Dropping Radix blocks can cause a cascade of drops that can overflow the stack.
// This custom drop implementation avoids this using an iterative approach.
impl
Drop
for
RadixTree
{
fn
drop
(
&
mut
self
)
{
let
mut
stack
:
Vec
<
SharedRadixBlock
>
=
Vec
::
new
();
// Break root -> children edge up front
{
let
mut
root
=
self
.root
.borrow_mut
();
stack
.extend
(
root
.children
.drain
()
.map
(|(
_
,
v
)|
v
));
}
// Remove all lookup references (they may include blocks not reachable from root)
for
(
_
,
worker_blocks
)
in
self
.lookup
.drain
()
{
stack
.extend
(
worker_blocks
.into_values
());
}
// Iteratively free any uniquely-owned blocks without recursion
while
let
Some
(
block
)
=
stack
.pop
()
{
match
Rc
::
try_unwrap
(
block
)
{
Ok
(
cell
)
=>
{
// We own the cell, so we can take inner and it will drop after this block.
let
mut
inner
:
RadixBlock
=
cell
.into_inner
();
stack
.extend
(
inner
.children
.drain
()
.map
(|(
_
,
v
)|
v
));
}
Err
(
rc
)
=>
{
// We don't own the cell, just call drop on it.
drop
(
rc
);
}
}
}
}
}
impl
RadixTree
{
/// Create a new `RadixTree`.
///
/// ### Returns
///
/// A new `RadixTree`.
pub
fn
new_with_frequency
(
expiration_duration
:
Option
<
Duration
>
)
->
Self
{
Self
{
root
:
Rc
::
new
(
RefCell
::
new
(
RadixBlock
::
new
())),
lookup
:
HashMap
::
new
(),
expiration_duration
,
}
}
pub
fn
new
()
->
Self
{
Self
::
new_with_frequency
(
None
)
}
/// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
///
/// ### Arguments
///
/// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
/// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
pub
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
)
->
OverlapScores
{
let
mut
scores
=
OverlapScores
::
new
();
let
mut
current
=
self
.root
.clone
();
let
now
=
Instant
::
now
();
tracing
::
trace!
(
"RadixTree::find_matches: looking for sequence={:?}"
,
sequence
.iter
()
.map
(|
h
|
h
.0
)
.collect
::
<
Vec
<
_
>>
()
);
for
(
idx
,
block_hash
)
in
sequence
.iter
()
.enumerate
()
{
let
next_block
=
{
let
current_borrow
=
current
.borrow
();
current_borrow
.children
.get
(
block_hash
)
.cloned
()
};
if
let
Some
(
block
)
=
next_block
{
scores
.update_scores
(
block
.borrow
()
.workers
.iter
());
if
let
Some
(
expiration_duration
)
=
self
.expiration_duration
{
let
mut
block_mut
=
block
.borrow_mut
();
while
let
Some
(
access_time
)
=
block_mut
.recent_uses
.front
()
{
if
now
.duration_since
(
*
access_time
)
>
expiration_duration
{
block_mut
.recent_uses
.pop_front
();
}
else
{
break
;
}
}
scores
.add_frequency
(
block_mut
.recent_uses
.len
());
block_mut
.recent_uses
.push_back
(
now
);
}
if
early_exit
&&
block
.borrow
()
.workers
.len
()
==
1
{
break
;
}
current
=
block
;
}
else
{
tracing
::
trace!
(
"RadixTree::find_matches: block not found at index {} for hash {}"
,
idx
,
block_hash
.0
);
break
;
}
}
tracing
::
trace!
(
"RadixTree::find_matches: final scores={:?}"
,
scores
.scores
);
// Populate tree sizes for all workers that have scores
for
worker
in
scores
.scores
.keys
()
{
let
tree_size
=
self
.lookup
.get
(
worker
)
.expect
(
"worker in scores must exist in lookup table"
)
.len
();
scores
.tree_sizes
.insert
(
*
worker
,
tree_size
);
}
scores
}
/// Apply a [`RouterEvent`] to the radix tree.
///
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
let
(
worker_id
,
kv_event
)
=
(
event
.worker_id
,
event
.event
);
let
(
id
,
op
)
=
(
kv_event
.event_id
,
kv_event
.data
);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
kv_event
.dp_rank
);
tracing
::
trace!
(
id
,
"RadixTree::apply_event: Store operation: {:?}"
,
op
);
let
worker_lookup
=
self
.lookup
.entry
(
worker
)
.or_default
();
match
op
{
KvCacheEventData
::
Stored
(
op
)
=>
{
// find the parent block from this worker's lookup
let
mut
current
=
match
op
.parent_hash
{
Some
(
parent
)
=>
match
worker_lookup
.get
(
&
parent
)
{
Some
(
current
)
=>
current
.clone
(),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
parent_hash
=
?
op
.parent_hash
,
num_blocks
=
op
.blocks
.len
(),
"Failed to find parent block; skipping store operation"
);
return
Err
(
KvCacheEventError
::
ParentBlockNotFound
);
}
},
None
=>
self
.root
.clone
(),
};
for
block_data
in
op
.blocks
{
let
mut
parent_mut
=
current
.borrow_mut
();
let
child
=
match
parent_mut
.children
.get
(
&
block_data
.tokens_hash
)
{
Some
(
block
)
=>
{
// Verify our simplifying assumption: block_hash is uniform across workers
if
block
.borrow
()
.block_hash
!=
Some
(
block_data
.block_hash
)
{
tracing
::
warn!
(
expected
=
?
block_data
.block_hash
,
actual
=
?
block
.borrow
()
.block_hash
,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
block
.clone
()
}
None
=>
{
// create new block or reuse existing from worker's lookup
let
new_block
=
worker_lookup
.get
(
&
block_data
.block_hash
)
.cloned
()
.unwrap_or_else
(||
{
Rc
::
new
(
RefCell
::
new
(
RadixBlock
::
with_hash
(
block_data
.block_hash
,
)))
});
// insert into radix tree
parent_mut
.children
.insert
(
block_data
.tokens_hash
,
new_block
.clone
());
new_block
}
};
// Update child and check for self referential blocks
{
// Try to borrow the child mutably - if it fails, it's already borrowed
// which means a self referencing block.
let
mut
child_mut
=
match
child
.try_borrow_mut
()
{
Ok
(
b
)
=>
b
,
Err
(
_
)
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
block_hash
=
?
block_data
.block_hash
,
"Detected self referencing block in store event; rejecting sequence"
);
return
Err
(
KvCacheEventError
::
InvalidBlockSequence
);
}
};
// add our worker to the block
child_mut
.workers
.insert
(
worker
);
}
// add the block to the worker's lookup table
worker_lookup
.insert
(
block_data
.block_hash
,
child
.clone
());
// drop child so we can shift current to this block
drop
(
parent_mut
);
current
=
child
;
}
Ok
(())
}
KvCacheEventData
::
Removed
(
remove
)
=>
{
let
mut
kv_cache_err
:
Option
<
KvCacheEventError
>
=
None
;
for
block
in
remove
.block_hashes
{
// lookup block in worker's table
let
entry
=
match
worker_lookup
.get
(
&
block
)
{
Some
(
entry
)
=>
entry
.clone
(),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
block_hash
=
?
block
,
"Failed to find block to remove; skipping remove operation"
);
// Kv cache removed events may be batched; we should try to apply all
// operations in the batch before returning an error. Return the first
// error.
if
kv_cache_err
.is_none
()
{
kv_cache_err
=
Some
(
KvCacheEventError
::
BlockNotFound
);
}
continue
;
}
};
let
mut
guard
=
entry
.borrow_mut
();
guard
.workers
.remove
(
&
worker
);
if
guard
.workers
.is_empty
()
{
// if no workers are using this block, that is true for all children
guard
.children
.clear
();
}
// remove the block from the worker's lookup table
worker_lookup
.remove
(
&
block
);
}
kv_cache_err
.map_or
(
Ok
(()),
Err
)
}
KvCacheEventData
::
Cleared
=>
{
self
.clear_all_blocks
(
worker
.worker_id
);
Ok
(())
}
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn
remove_or_clear_worker_blocks
(
&
mut
self
,
worker_id
:
WorkerId
,
keep_worker
:
bool
)
{
// Collect all WorkerWithDpRank keys that match this worker_id
let
workers
:
Vec
<
WorkerWithDpRank
>
=
self
.lookup
.keys
()
.filter
(|
w
|
w
.worker_id
==
worker_id
)
.copied
()
.collect
();
for
worker
in
workers
{
if
let
Some
((
worker_key
,
blocks
))
=
self
.lookup
.remove_entry
(
&
worker
)
{
for
(
_
,
block
)
in
blocks
{
block
.borrow_mut
()
.workers
.remove
(
&
worker
);
// If no workers are using this block, that is true for all children
if
block
.borrow
()
.workers
.is_empty
()
{
block
.borrow_mut
()
.children
.clear
();
}
}
if
keep_worker
{
// Re-insert worker with empty blocks map to keep it tracked
self
.lookup
.insert
(
worker_key
,
HashMap
::
new
());
}
}
}
}
pub
fn
remove_worker
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
false
);
}
pub
fn
clear_all_blocks
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
true
);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
let
mut
worker_ids
:
Vec
<
WorkerId
>
=
self
.lookup
.keys
()
.map
(|
w
|
w
.worker_id
)
.collect
();
worker_ids
.sort_unstable
();
worker_ids
.dedup
();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
tracing
::
debug!
(
"Dumping radix tree as events (contains information about {:?} workers)"
,
self
.lookup
.len
()
);
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
// Queue entries: (current_block, parent_hash, tokens_hash)
let
mut
queue
=
VecDeque
::
new
();
// Process root's children first
let
root_borrow
=
self
.root
.borrow
();
for
(
tokens_hash
,
child_block
)
in
&
root_borrow
.children
{
queue
.push_back
((
child_block
.clone
(),
None
,
*
tokens_hash
));
}
drop
(
root_borrow
);
while
let
Some
((
current_block
,
parent_hash
,
tokens_hash
))
=
queue
.pop_front
()
{
let
current_borrow
=
current_block
.borrow
();
// Get this block's hash (same for all workers)
let
block_hash
=
current_borrow
.block_hash
.expect
(
"non-root block must have block_hash"
);
// For each worker that has this block
for
worker
in
&
current_borrow
.workers
{
// Create a store event for this worker
let
event
=
RouterEvent
{
worker_id
:
worker
.worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
,
mm_extra_info
:
None
,
tokens_hash
,
}],
}),
dp_rank
:
worker
.dp_rank
,
},
};
events
.push
(
event
);
event_id
+=
1
;
}
// Enqueue children with this block's hash as their parent
for
(
child_tokens_hash
,
child_block
)
in
&
current_borrow
.children
{
queue
.push_back
((
child_block
.clone
(),
Some
(
block_hash
),
*
child_tokens_hash
));
}
}
events
}
pub
fn
current_size
(
&
self
)
->
usize
{
self
.lookup
.values
()
.map
(|
m
|
m
.len
())
.sum
()
}
}
/// Metrics for the KV Indexer.
#[derive(Clone)]
pub
struct
KvIndexerMetrics
{
...
...
@@ -718,63 +327,6 @@ impl KvIndexerMetrics {
}
}
/// Scores representing the overlap of workers (with their dp_rank).
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
OverlapScores
{
// map of worker (with dp_rank) to score
pub
scores
:
HashMap
<
WorkerWithDpRank
,
u32
>
,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub
frequencies
:
Vec
<
usize
>
,
// Map of worker to their tree size (number of blocks in the tree for that worker)
pub
tree_sizes
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
}
impl
Default
for
OverlapScores
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
impl
OverlapScores
{
/// Create a new `OverlapScores`.
///
/// ### Returns
///
/// A new `OverlapScores`.
pub
fn
new
()
->
Self
{
Self
{
scores
:
HashMap
::
new
(),
frequencies
:
Vec
::
with_capacity
(
32
),
tree_sizes
:
HashMap
::
new
(),
}
}
/// Update the scores with a set of workers.
///
/// ### Arguments
///
/// * `workers` - An iterator over `WorkerWithDpRank` references.
pub
fn
update_scores
<
'a
,
I
>
(
&
mut
self
,
workers
:
I
)
where
I
:
IntoIterator
<
Item
=
&
'a
WorkerWithDpRank
>
,
{
for
worker
in
workers
{
let
score
=
self
.scores
.entry
(
*
worker
)
.or_insert
(
0
);
*
score
+=
1
;
}
}
/// Add an entry in the frequency list.
pub
fn
add_frequency
(
&
mut
self
,
frequency
:
usize
)
{
if
frequency
!=
0
{
self
.frequencies
.last
()
.inspect
(|
elem
|
debug_assert!
(
**
elem
>=
frequency
));
self
.frequencies
.push
(
frequency
);
}
}
}
/// A request to find matches in the Radix Tree.
pub
struct
MatchRequest
{
/// A vector of `LocalBlockHash` representing the sequence to match.
...
...
@@ -2061,6 +1613,7 @@ mod tests {
use
crate
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
};
use
rstest
::
rstest
;
use
rstest_reuse
::{
self
,
*
};
use
std
::
time
::
Instant
;
use
tokio
::
time
;
use
tokio_util
::
sync
::
CancellationToken
;
...
...
@@ -2105,584 +1658,6 @@ mod tests {
}
}
fn
create_remove_event
(
worker_id
:
WorkerId
,
event_id
:
u64
,
hashes
:
Vec
<
u64
>
)
->
RouterEvent
{
RouterEvent
{
worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
hashes
.iter
()
.map
(|
i
|
ExternalSequenceBlockHash
(
*
i
*
100
))
.collect
(),
}),
dp_rank
:
0
,
},
}
}
#[test]
fn
test_radix_tree
()
{
setup
();
let
mut
trie
=
RadixTree
::
new
();
let
worker_1
=
0
;
let
worker_2
=
1
;
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
trie
.lookup
.len
(),
1
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
1
);
trie
.apply_event
(
create_store_event
(
worker_2
,
1
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
1
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_remove_event
(
worker_2
,
2
,
vec!
[
5
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_remove_event
(
worker_2
,
3
,
vec!
[
4
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_store_event
(
worker_2
,
4
,
vec!
[
2
,
6
,
7
],
Some
(
ExternalSequenceBlockHash
(
100
)),
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
2
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
4
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
200
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
}
#[test]
fn
test_radix_tree_apply_event_errors
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
// Parent block not found
let
result
=
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
Some
(
ExternalSequenceBlockHash
(
12345
)),
));
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
ParentBlockNotFound
));
// Block not found for remove event.
let
result
=
trie
.apply_event
(
create_remove_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
]));
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
BlockNotFound
));
// Parent appears in blocks: parent=1, blocks=[1, 2, 3]
// This should be rejected as block 1 (hash 100) is the parent - this is
// a self referencing block.
trie
.apply_event
(
create_store_event
(
worker_0
,
4
,
vec!
[
1
],
None
))
.unwrap
();
let
result
=
trie
.apply_event
(
create_store_event
(
worker_0
,
5
,
vec!
[
1
,
2
,
3
],
Some
(
ExternalSequenceBlockHash
(
100
)),
));
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
InvalidBlockSequence
));
}
#[test]
fn
test_radix_tree_large_stores
()
{
setup
();
let
mut
trie
=
RadixTree
::
new
();
for
i
in
0
..=
16
{
let
len
=
1
<<
i
;
let
worker_id
=
i
;
tracing
::
info!
(
"Testing sequence of length {}"
,
len
);
let
sequence
=
(
1
..
len
+
1
)
.collect
::
<
Vec
<
u64
>>
();
trie
.apply_event
(
create_store_event
(
worker_id
,
1
,
sequence
,
None
))
.unwrap
();
}
}
#[test]
fn
test_remove_worker
()
{
setup
();
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
assert
!
(
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
.is_empty
()
);
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
0
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
0
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
;
assert
!
(
result
.len
()
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)]
==
1
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
trie
.remove_worker
(
worker_0
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
;
assert
!
(
result
.len
()
==
1
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
}
#[test]
fn
test_clear_all_blocks
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
assert
!
(
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
.is_empty
()
);
// Test clearing an empty worker
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
// Test clearing a worker with shared blocks
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
0
,
1
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
0
,
2
,
3
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
;
assert
!
(
result
.len
()
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)]
==
1
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.is_empty
()
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
2
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
2
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
),
LocalBlockHash
(
3
)],
false
,
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
// Test re-adding blocks after clearing worker
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
4
,
5
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
4
),
LocalBlockHash
(
5
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)],
2
);
// Test multiple clears
trie
.clear_all_blocks
(
worker_0
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
// Test clearing all workers
trie
.clear_all_blocks
(
worker_0
);
trie
.clear_all_blocks
(
worker_1
);
assert
!
(
!
trie
.lookup
.is_empty
());
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.is_empty
()
);
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.is_empty
()
);
// Test clearing a worker that has been removed
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
6
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
6
],
None
))
.unwrap
();
trie
.remove_worker
(
worker_0
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
6
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
// Test clearing a worker that doesn't exist
let
worker_fake
=
2
;
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_fake
))
);
trie
.clear_all_blocks
(
worker_fake
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_fake
))
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
6
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
}
#[test]
fn
test_early_stopping
()
{
setup
();
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
0
,
1
,
2
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
0
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
true
,
)
.scores
;
assert
!
(
result
.len
()
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)]
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
)],
true
)
.scores
;
assert
!
(
result
.len
()
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)]
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
}
#[rstest]
#[case(
11
)]
#[case(
32
)]
#[case(
64
)]
fn
test_compute_block_hash_for_seq
(
#[case]
kv_block_size
:
u32
)
{
setup
();
// create a sequence of 64 elements
let
sequence
=
(
0
..
kv_block_size
)
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
1
);
// create a sequence of 65 elements
let
sequence
=
(
0
..
(
kv_block_size
+
1
))
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
1
);
// create a sequence of 129 elements
let
sequence
=
(
0
..
(
2
*
kv_block_size
+
1
))
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
2
);
}
fn
make_indexer
(
token
:
&
CancellationToken
,
num_shards
:
usize
,
...
...
@@ -2874,54 +1849,6 @@ mod tests {
assert_eq!
(
overlap
.frequencies
,
vec!
[
3
,
3
,
3
,
2
]);
}
#[test]
fn
test_router_event_new
()
{
setup
();
let
worker_id
=
0
;
let
kv_cache_event
=
KvCacheEvent
{
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
0
),
mm_extra_info
:
None
,
tokens_hash
:
LocalBlockHash
(
13226331709069118873
),
}],
}),
dp_rank
:
0
,
};
let
router_event
=
RouterEvent
::
new
(
worker_id
,
kv_cache_event
);
assert_eq!
(
router_event
.worker_id
,
worker_id
);
assert_eq!
(
router_event
.event.event_id
,
1
);
if
let
KvCacheEventData
::
Stored
(
store_op
)
=
&
router_event
.event.data
{
assert_eq!
(
store_op
.blocks
.len
(),
1
);
assert_eq!
(
store_op
.blocks
[
0
]
.tokens_hash
,
compute_block_hash
(
b
"test data"
)
);
assert_eq!
(
store_op
.blocks
[
0
]
.block_hash
,
ExternalSequenceBlockHash
(
0
));
}
else
{
panic!
(
"Expected KvCacheEventData::Stored"
);
}
}
#[test]
fn
test_radix_tree_default
()
{
setup
();
let
radix_tree
:
RadixTree
=
Default
::
default
();
assert
!
(
radix_tree
.root
.borrow
()
.children
.is_empty
());
assert
!
(
radix_tree
.root
.borrow
()
.workers
.is_empty
());
assert
!
(
radix_tree
.lookup
.is_empty
());
}
#[test]
fn
test_overlap_scores_default
()
{
setup
();
let
overlap_scores
:
OverlapScores
=
Default
::
default
();
assert
!
(
overlap_scores
.scores
.is_empty
());
}
#[tokio::test]
async
fn
test_dump_tree_as_events_round_trip
()
{
setup
();
...
...
@@ -3142,126 +2069,6 @@ mod tests {
);
}
#[test]
fn
test_remove_worker_verifies_hash_removal
()
{
setup
();
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
let
worker_2
=
2
;
// Add blocks for multiple workers
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_2
,
0
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
// Verify worker_0 has 3 blocks in lookup
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.len
(),
3
);
// Verify that blocks have the correct workers
let
block_1
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
100
))
.unwrap
();
assert_eq!
(
block_1
.borrow
()
.workers
.len
(),
3
);
// worker_0, worker_1, and worker_2 (all have hash 1)
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
);
// Remove worker_0
trie
.remove_worker
(
worker_0
);
// Verify worker_0 is completely removed from lookup table
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
// Verify that worker_0's hash is removed from the workers set
let
block_1
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
100
))
.unwrap
();
assert_eq!
(
block_1
.borrow
()
.workers
.len
(),
2
);
// worker_1 and worker_2 remain
assert
!
(
!
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
);
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let
block_2
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
200
))
.unwrap
();
assert_eq!
(
block_2
.borrow
()
.workers
.len
(),
1
);
// only worker_1
assert
!
(
block_2
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
// Verify match results no longer include worker_0
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
)
.scores
;
assert_eq!
(
result
.len
(),
2
);
assert
!
(
!
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
)));
}
// LocalKvIndexer tests
fn
make_indexer_with_events
(
ids
:
&
[
u64
])
->
LocalKvIndexer
{
let
indexer
=
LocalKvIndexer
::
new
(
...
...
@@ -3575,3 +2382,302 @@ mod tests {
}
}
}
/// Tests for KvIndex enum (parametrized over RadixTree and FlatHashMap variants).
#[cfg(test)]
mod
kv_index_tests
{
use
super
::
*
;
use
crate
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
,
compute_seq_hash_for_block
};
use
rstest
::
rstest
;
use
rstest_reuse
::{
self
,
*
};
/// Create a store event with proper sequence hashes computed from local hashes.
fn
make_store_event
(
worker_id
:
u64
,
local_hashes
:
&
[
u64
])
->
RouterEvent
{
let
local_block_hashes
:
Vec
<
LocalBlockHash
>
=
local_hashes
.iter
()
.map
(|
&
h
|
LocalBlockHash
(
h
))
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
local_block_hashes
);
RouterEvent
{
worker_id
,
event
:
KvCacheEvent
{
event_id
:
0
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
local_block_hashes
.iter
()
.zip
(
seq_hashes
.iter
())
.map
(|(
&
local
,
&
seq
)|
KvCacheStoredBlockData
{
tokens_hash
:
local
,
block_hash
:
ExternalSequenceBlockHash
(
seq
),
mm_extra_info
:
None
,
})
.collect
(),
}),
dp_rank
:
0
,
},
}
}
/// Create a remove event for blocks with given local hashes.
fn
make_remove_event
(
worker_id
:
u64
,
local_hashes
:
&
[
u64
])
->
RouterEvent
{
let
local_block_hashes
:
Vec
<
LocalBlockHash
>
=
local_hashes
.iter
()
.map
(|
&
h
|
LocalBlockHash
(
h
))
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
local_block_hashes
);
RouterEvent
{
worker_id
,
event
:
KvCacheEvent
{
event_id
:
0
,
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
seq_hashes
.iter
()
.map
(|
&
h
|
ExternalSequenceBlockHash
(
h
))
.collect
(),
}),
dp_rank
:
0
,
},
}
}
#[template]
#[rstest]
fn
kv_index_template
(
#[values(
"tree"
,
"flat"
)]
variant
:
&
str
)
{}
fn
make_kv_index
(
variant
:
&
str
)
->
KvIndex
{
match
variant
{
"tree"
=>
KvIndex
::
new_tree
(),
"flat"
=>
KvIndex
::
new_flat
(),
_
=>
panic!
(
"Unknown variant: {}"
,
variant
),
}
}
#[apply(kv_index_template)]
fn
test_store_and_find
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Store a sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
assert_eq!
(
index
.current_size
(),
3
);
// Find matches using local hashes
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
#[apply(kv_index_template)]
fn
test_partial_match
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Store [1, 2, 3] for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
// Find matches for [1, 2, 999] - should match first 2 then stop
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
999
)],
false
,
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
#[apply(kv_index_template)]
fn
test_remove
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Store sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
assert_eq!
(
index
.current_size
(),
3
);
// Remove all blocks
index
.apply_event
(
make_remove_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
assert_eq!
(
index
.current_size
(),
0
);
// Find should return nothing
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert
!
(
scores
.scores
.is_empty
());
}
#[apply(kv_index_template)]
fn
test_multiple_workers_shared_prefix
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
3
]))
.unwrap
();
// Query [1] - both workers should match
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
)],
false
);
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
false
);
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
}
#[apply(kv_index_template)]
fn
test_remove_worker
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.unwrap
();
assert_eq!
(
index
.current_size
(),
6
);
index
.remove_worker
(
0
);
assert_eq!
(
index
.current_size
(),
3
);
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
#[apply(kv_index_template)]
fn
test_get_workers
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
2
,
&
[
1
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
1
]))
.unwrap
();
let
workers
=
index
.get_workers
();
assert_eq!
(
workers
,
vec!
[
0
,
1
,
2
]);
}
#[apply(kv_index_template)]
fn
test_early_exit
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Worker 0 has [0, 1, 2], Worker 1 has [0] only
index
.apply_event
(
make_store_event
(
0
,
&
[
0
,
1
,
2
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
0
]))
.unwrap
();
// Query [0, 1, 2] with early_exit=true
// Should stop after [0, 1] since only worker 0 has block 1
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
true
,
);
// Both workers should appear in results
assert_eq!
(
scores
.scores
.len
(),
2
);
// Worker 0 got 2 points (blocks 0 and 1, stopped early)
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
// Worker 1 got 1 point (block 0 only)
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
// Without early_exit, worker 0 should get all 3 blocks
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
false
,
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
#[apply(kv_index_template)]
fn
test_large_stores
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Test sequences of increasing sizes
for
i
in
0
..
10
{
let
len
=
1
<<
i
;
// 1, 2, 4, 8, ..., 512
let
worker_id
=
i
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
len
)
.map
(|
x
|
x
+
(
i
as
u64
*
10000
))
.collect
();
index
.apply_event
(
make_store_event
(
worker_id
,
&
sequence
))
.unwrap
();
assert
!
(
index
.current_size
()
>
0
);
}
}
#[apply(kv_index_template)]
fn
test_dump_and_restore
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Store some data
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
4
]))
.unwrap
();
let
original_size
=
index
.current_size
();
let
workers_before
=
index
.get_workers
();
// Dump the tree as events
let
events
=
index
.dump_tree_as_events
();
assert
!
(
!
events
.is_empty
());
// Create a new index and replay events
let
mut
restored
=
make_kv_index
(
variant
);
for
event
in
events
{
let
_
=
restored
.apply_event
(
event
);
}
// Verify the restored index has same size and workers
assert_eq!
(
restored
.current_size
(),
original_size
);
assert_eq!
(
restored
.get_workers
(),
workers_before
);
// Verify find_matches produces same results
let
original_scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
false
);
let
restored_scores
=
restored
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)],
false
);
assert_eq!
(
original_scores
.scores
,
restored_scores
.scores
);
}
#[apply(kv_index_template)]
fn
test_clear_all_blocks
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
// Store some data for two workers
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.unwrap
();
assert_eq!
(
index
.current_size
(),
6
);
// Clear worker 0's blocks
index
.clear_all_blocks
(
0
);
// Worker 0's blocks should be gone, worker 1's remain
assert_eq!
(
index
.current_size
(),
3
);
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
#[apply(kv_index_template)]
fn
test_empty_query
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
// Empty query should return empty scores
let
scores
=
index
.find_matches
(
vec!
[],
false
);
assert
!
(
scores
.scores
.is_empty
());
}
#[apply(kv_index_template)]
fn
test_miss_query
(
variant
:
&
str
)
{
let
mut
index
=
make_kv_index
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.unwrap
();
// Query for non-existent blocks
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
999
),
LocalBlockHash
(
998
)],
false
);
assert
!
(
scores
.scores
.is_empty
());
}
}
lib/kv-router/src/lib.rs
View file @
fc92fc18
...
...
@@ -7,9 +7,16 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub
mod
approx
;
pub
mod
flat_hashmap
;
pub
mod
indexer
;
pub
mod
protocols
;
pub
mod
radix_tree
;
// Re-export key types for convenience
pub
use
indexer
::{
MaybeError
,
RadixTree
,
RouterEvent
};
pub
use
protocols
::{
LocalBlockHash
,
WorkerId
,
compute_block_hash_for_seq
};
pub
use
flat_hashmap
::
FlatHashMap
;
pub
use
indexer
::
MaybeError
;
pub
use
protocols
::{
KvCacheEventError
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
WorkerId
,
compute_block_hash_for_seq
,
};
pub
use
radix_tree
::
RadixTree
;
lib/kv-router/src/protocols.rs
View file @
fc92fc18
...
...
@@ -453,6 +453,105 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
}
}
// ------
// Router Event Types
// ------
/// Errors that can occur during KV Cache Event processing.
#[derive(Debug,
thiserror::Error)]
pub
enum
KvCacheEventError
{
#[error(
"Failed to find parent block"
)]
ParentBlockNotFound
,
#[error(
"Failed to find block"
)]
BlockNotFound
,
#[error(
"Invalid block sequence"
)]
InvalidBlockSequence
,
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug,
Clone,
Serialize,
Deserialize,
PartialEq)]
pub
struct
RouterEvent
{
/// The ID of the worker emitting the event.
pub
worker_id
:
WorkerId
,
/// The cache event associated with the worker.
pub
event
:
KvCacheEvent
,
}
impl
RouterEvent
{
/// Create a new `RouterEvent`.
///
/// ### Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
///
/// ### Returns
///
/// A new `RouterEvent`.
pub
fn
new
(
worker_id
:
WorkerId
,
event
:
KvCacheEvent
)
->
Self
{
Self
{
worker_id
,
event
}
}
}
/// Scores representing the overlap of workers (with their dp_rank).
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
OverlapScores
{
/// Map of worker (with dp_rank) to score.
pub
scores
:
std
::
collections
::
HashMap
<
WorkerWithDpRank
,
u32
>
,
/// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub
frequencies
:
Vec
<
usize
>
,
/// Map of worker to their tree size (number of blocks in the tree for that worker).
pub
tree_sizes
:
std
::
collections
::
HashMap
<
WorkerWithDpRank
,
usize
>
,
}
impl
Default
for
OverlapScores
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
impl
OverlapScores
{
/// Create a new `OverlapScores`.
///
/// ### Returns
///
/// A new `OverlapScores`.
pub
fn
new
()
->
Self
{
Self
{
scores
:
std
::
collections
::
HashMap
::
new
(),
frequencies
:
Vec
::
with_capacity
(
32
),
tree_sizes
:
std
::
collections
::
HashMap
::
new
(),
}
}
/// Update the scores with a set of workers.
///
/// ### Arguments
///
/// * `workers` - An iterator over `WorkerWithDpRank` references.
pub
fn
update_scores
<
'a
,
I
>
(
&
mut
self
,
workers
:
I
)
where
I
:
IntoIterator
<
Item
=
&
'a
WorkerWithDpRank
>
,
{
for
worker
in
workers
{
let
score
=
self
.scores
.entry
(
*
worker
)
.or_insert
(
0
);
*
score
+=
1
;
}
}
/// Add an entry in the frequency list.
pub
fn
add_frequency
(
&
mut
self
,
frequency
:
usize
)
{
if
frequency
!=
0
{
self
.frequencies
.last
()
.inspect
(|
elem
|
debug_assert!
(
**
elem
>=
frequency
));
self
.frequencies
.push
(
frequency
);
}
}
}
// ------
// TokensWithHashes
// ------
...
...
@@ -556,8 +655,67 @@ impl TokensWithHashes {
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
rstest
::
rstest
;
use
serde_json
;
#[test]
fn
test_router_event_new
()
{
let
worker_id
=
0
;
let
kv_cache_event
=
KvCacheEvent
{
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
0
),
mm_extra_info
:
None
,
tokens_hash
:
LocalBlockHash
(
13226331709069118873
),
}],
}),
dp_rank
:
0
,
};
let
router_event
=
RouterEvent
::
new
(
worker_id
,
kv_cache_event
);
assert_eq!
(
router_event
.worker_id
,
worker_id
);
assert_eq!
(
router_event
.event.event_id
,
1
);
if
let
KvCacheEventData
::
Stored
(
store_op
)
=
&
router_event
.event.data
{
assert_eq!
(
store_op
.blocks
.len
(),
1
);
assert_eq!
(
store_op
.blocks
[
0
]
.tokens_hash
,
compute_block_hash
(
b
"test data"
)
);
assert_eq!
(
store_op
.blocks
[
0
]
.block_hash
,
ExternalSequenceBlockHash
(
0
));
}
else
{
panic!
(
"Expected KvCacheEventData::Stored"
);
}
}
#[test]
fn
test_overlap_scores_default
()
{
let
overlap_scores
:
OverlapScores
=
Default
::
default
();
assert
!
(
overlap_scores
.scores
.is_empty
());
}
#[rstest]
#[case(
11
)]
#[case(
32
)]
#[case(
64
)]
fn
test_compute_block_hash_for_seq
(
#[case]
kv_block_size
:
u32
)
{
// create a sequence of kv_block_size elements
let
sequence
=
(
0
..
kv_block_size
)
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
1
);
// create a sequence of kv_block_size + 1 elements
let
sequence
=
(
0
..
(
kv_block_size
+
1
))
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
1
);
// create a sequence of 2 * kv_block_size + 1 elements
let
sequence
=
(
0
..
(
2
*
kv_block_size
+
1
))
.collect
::
<
Vec
<
u32
>>
();
let
hashes
=
compute_block_hash_for_seq
(
&
sequence
,
kv_block_size
,
None
);
assert_eq!
(
hashes
.len
(),
2
);
}
#[test]
fn
test_local_block_hash_serialization
()
{
let
hash
=
LocalBlockHash
(
12345
);
...
...
lib/kv-router/src/radix_tree.rs
0 → 100644
View file @
fc92fc18
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Radix Tree implementation for KV cache routing.
//!
//! This module provides a radix tree (prefix tree) data structure optimized for
//! efficient KV cache block lookup and management in distributed LLM inference.
//!
//! # Overview
//!
//! The main components include:
//!
//! - **RadixTree**: The main data structure with nodes (`RadixBlock`) containing
//! children and associated worker IDs. Allows efficient storage and retrieval
//! of data blocks based on their hashes.
use
std
::{
cell
::
RefCell
,
collections
::{
HashMap
,
HashSet
,
VecDeque
},
rc
::
Rc
,
time
::{
Duration
,
Instant
},
};
use
crate
::
protocols
::
*
;
/// A shared reference to a [`RadixBlock`].
pub
(
crate
)
type
SharedRadixBlock
=
Rc
<
RefCell
<
RadixBlock
>>
;
/// A block in the Radix Tree.
#[derive(Debug)]
pub
(
crate
)
struct
RadixBlock
{
/// A map of child blocks, keyed by their local block hash.
pub
(
crate
)
children
:
HashMap
<
LocalBlockHash
,
SharedRadixBlock
>
,
/// The set of workers that have this block cached.
pub
(
crate
)
workers
:
HashSet
<
WorkerWithDpRank
>
,
/// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
pub
(
crate
)
block_hash
:
Option
<
ExternalSequenceBlockHash
>
,
/// A buffer of times that this block was last traversed
pub
(
crate
)
recent_uses
:
VecDeque
<
Instant
>
,
}
impl
RadixBlock
{
/// Create a new `RadixBlock` (used for root node).
///
/// ### Returns
///
/// A new `RadixBlock` with no block_hash.
pub
fn
new
()
->
Self
{
Self
{
children
:
HashMap
::
new
(),
workers
:
HashSet
::
new
(),
block_hash
:
None
,
recent_uses
:
VecDeque
::
new
(),
}
}
/// Create a new `RadixBlock` with a specific block hash.
///
/// ### Returns
///
/// A new `RadixBlock` with the given block_hash.
pub
fn
with_hash
(
block_hash
:
ExternalSequenceBlockHash
)
->
Self
{
Self
{
children
:
HashMap
::
new
(),
workers
:
HashSet
::
new
(),
block_hash
:
Some
(
block_hash
),
recent_uses
:
VecDeque
::
new
(),
}
}
}
pub
struct
RadixTree
{
/// This is the root of the radix/prefix tree
/// This will only contain root blocks
pub
(
crate
)
root
:
SharedRadixBlock
,
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
pub
(
crate
)
lookup
:
HashMap
<
WorkerWithDpRank
,
HashMap
<
ExternalSequenceBlockHash
,
SharedRadixBlock
>>
,
/// The time buffer the radix tree should check when considering frequence of block accesses
pub
(
crate
)
expiration_duration
:
Option
<
Duration
>
,
}
impl
Default
for
RadixTree
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
// Dropping Radix blocks can cause a cascade of drops that can overflow the stack.
// This custom drop implementation avoids this using an iterative approach.
impl
Drop
for
RadixTree
{
fn
drop
(
&
mut
self
)
{
let
mut
stack
:
Vec
<
SharedRadixBlock
>
=
Vec
::
new
();
// Break root -> children edge up front
{
let
mut
root
=
self
.root
.borrow_mut
();
stack
.extend
(
root
.children
.drain
()
.map
(|(
_
,
v
)|
v
));
}
// Remove all lookup references (they may include blocks not reachable from root)
for
(
_
,
worker_blocks
)
in
self
.lookup
.drain
()
{
stack
.extend
(
worker_blocks
.into_values
());
}
// Iteratively free any uniquely-owned blocks without recursion
while
let
Some
(
block
)
=
stack
.pop
()
{
match
Rc
::
try_unwrap
(
block
)
{
Ok
(
cell
)
=>
{
// We own the cell, so we can take inner and it will drop after this block.
let
mut
inner
:
RadixBlock
=
cell
.into_inner
();
stack
.extend
(
inner
.children
.drain
()
.map
(|(
_
,
v
)|
v
));
}
Err
(
rc
)
=>
{
// We don't own the cell, just call drop on it.
drop
(
rc
);
}
}
}
}
}
impl
RadixTree
{
/// Create a new `RadixTree`.
///
/// ### Returns
///
/// A new `RadixTree`.
pub
fn
new_with_frequency
(
expiration_duration
:
Option
<
Duration
>
)
->
Self
{
Self
{
root
:
Rc
::
new
(
RefCell
::
new
(
RadixBlock
::
new
())),
lookup
:
HashMap
::
new
(),
expiration_duration
,
}
}
pub
fn
new
()
->
Self
{
Self
::
new_with_frequency
(
None
)
}
/// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
///
/// ### Arguments
///
/// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
/// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
pub
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
)
->
OverlapScores
{
let
mut
scores
=
OverlapScores
::
new
();
let
mut
current
=
self
.root
.clone
();
let
now
=
Instant
::
now
();
tracing
::
trace!
(
"RadixTree::find_matches: looking for sequence={:?}"
,
sequence
.iter
()
.map
(|
h
|
h
.0
)
.collect
::
<
Vec
<
_
>>
()
);
for
(
idx
,
block_hash
)
in
sequence
.iter
()
.enumerate
()
{
let
next_block
=
{
let
current_borrow
=
current
.borrow
();
current_borrow
.children
.get
(
block_hash
)
.cloned
()
};
if
let
Some
(
block
)
=
next_block
{
scores
.update_scores
(
block
.borrow
()
.workers
.iter
());
if
let
Some
(
expiration_duration
)
=
self
.expiration_duration
{
let
mut
block_mut
=
block
.borrow_mut
();
while
let
Some
(
access_time
)
=
block_mut
.recent_uses
.front
()
{
if
now
.duration_since
(
*
access_time
)
>
expiration_duration
{
block_mut
.recent_uses
.pop_front
();
}
else
{
break
;
}
}
scores
.add_frequency
(
block_mut
.recent_uses
.len
());
block_mut
.recent_uses
.push_back
(
now
);
}
if
early_exit
&&
block
.borrow
()
.workers
.len
()
==
1
{
break
;
}
current
=
block
;
}
else
{
tracing
::
trace!
(
"RadixTree::find_matches: block not found at index {} for hash {}"
,
idx
,
block_hash
.0
);
break
;
}
}
tracing
::
trace!
(
"RadixTree::find_matches: final scores={:?}"
,
scores
.scores
);
// Populate tree sizes for all workers that have scores
for
worker
in
scores
.scores
.keys
()
{
let
tree_size
=
self
.lookup
.get
(
worker
)
.expect
(
"worker in scores must exist in lookup table"
)
.len
();
scores
.tree_sizes
.insert
(
*
worker
,
tree_size
);
}
scores
}
/// Apply a [`RouterEvent`] to the radix tree.
///
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
let
(
worker_id
,
kv_event
)
=
(
event
.worker_id
,
event
.event
);
let
(
id
,
op
)
=
(
kv_event
.event_id
,
kv_event
.data
);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
kv_event
.dp_rank
);
tracing
::
trace!
(
id
,
"RadixTree::apply_event: Store operation: {:?}"
,
op
);
let
worker_lookup
=
self
.lookup
.entry
(
worker
)
.or_default
();
match
op
{
KvCacheEventData
::
Stored
(
op
)
=>
{
// find the parent block from this worker's lookup
let
mut
current
=
match
op
.parent_hash
{
Some
(
parent
)
=>
match
worker_lookup
.get
(
&
parent
)
{
Some
(
current
)
=>
current
.clone
(),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
parent_hash
=
?
op
.parent_hash
,
num_blocks
=
op
.blocks
.len
(),
"Failed to find parent block; skipping store operation"
);
return
Err
(
KvCacheEventError
::
ParentBlockNotFound
);
}
},
None
=>
self
.root
.clone
(),
};
for
block_data
in
op
.blocks
{
let
mut
parent_mut
=
current
.borrow_mut
();
let
child
=
match
parent_mut
.children
.get
(
&
block_data
.tokens_hash
)
{
Some
(
block
)
=>
{
// Verify our simplifying assumption: block_hash is uniform across workers
if
block
.borrow
()
.block_hash
!=
Some
(
block_data
.block_hash
)
{
tracing
::
warn!
(
expected
=
?
block_data
.block_hash
,
actual
=
?
block
.borrow
()
.block_hash
,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
block
.clone
()
}
None
=>
{
// create new block or reuse existing from worker's lookup
let
new_block
=
worker_lookup
.get
(
&
block_data
.block_hash
)
.cloned
()
.unwrap_or_else
(||
{
Rc
::
new
(
RefCell
::
new
(
RadixBlock
::
with_hash
(
block_data
.block_hash
,
)))
});
// insert into radix tree
parent_mut
.children
.insert
(
block_data
.tokens_hash
,
new_block
.clone
());
new_block
}
};
// Update child and check for self referential blocks
{
// Try to borrow the child mutably - if it fails, it's already borrowed
// which means a self referencing block.
let
mut
child_mut
=
match
child
.try_borrow_mut
()
{
Ok
(
b
)
=>
b
,
Err
(
_
)
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
block_hash
=
?
block_data
.block_hash
,
"Detected self referencing block in store event; rejecting sequence"
);
return
Err
(
KvCacheEventError
::
InvalidBlockSequence
);
}
};
// add our worker to the block
child_mut
.workers
.insert
(
worker
);
}
// add the block to the worker's lookup table
worker_lookup
.insert
(
block_data
.block_hash
,
child
.clone
());
// drop child so we can shift current to this block
drop
(
parent_mut
);
current
=
child
;
}
Ok
(())
}
KvCacheEventData
::
Removed
(
remove
)
=>
{
let
mut
kv_cache_err
:
Option
<
KvCacheEventError
>
=
None
;
for
block
in
remove
.block_hashes
{
// lookup block in worker's table
let
entry
=
match
worker_lookup
.get
(
&
block
)
{
Some
(
entry
)
=>
entry
.clone
(),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
id
,
block_hash
=
?
block
,
"Failed to find block to remove; skipping remove operation"
);
// Kv cache removed events may be batched; we should try to apply all
// operations in the batch before returning an error. Return the first
// error.
if
kv_cache_err
.is_none
()
{
kv_cache_err
=
Some
(
KvCacheEventError
::
BlockNotFound
);
}
continue
;
}
};
let
mut
guard
=
entry
.borrow_mut
();
guard
.workers
.remove
(
&
worker
);
if
guard
.workers
.is_empty
()
{
// if no workers are using this block, that is true for all children
guard
.children
.clear
();
}
// remove the block from the worker's lookup table
worker_lookup
.remove
(
&
block
);
}
kv_cache_err
.map_or
(
Ok
(()),
Err
)
}
KvCacheEventData
::
Cleared
=>
{
self
.clear_all_blocks
(
worker
.worker_id
);
Ok
(())
}
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn
remove_or_clear_worker_blocks
(
&
mut
self
,
worker_id
:
WorkerId
,
keep_worker
:
bool
)
{
// Collect all WorkerWithDpRank keys that match this worker_id
let
workers
:
Vec
<
WorkerWithDpRank
>
=
self
.lookup
.keys
()
.filter
(|
w
|
w
.worker_id
==
worker_id
)
.copied
()
.collect
();
for
worker
in
workers
{
if
let
Some
((
worker_key
,
blocks
))
=
self
.lookup
.remove_entry
(
&
worker
)
{
for
(
_
,
block
)
in
blocks
{
block
.borrow_mut
()
.workers
.remove
(
&
worker
);
// If no workers are using this block, that is true for all children
if
block
.borrow
()
.workers
.is_empty
()
{
block
.borrow_mut
()
.children
.clear
();
}
}
if
keep_worker
{
// Re-insert worker with empty blocks map to keep it tracked
self
.lookup
.insert
(
worker_key
,
HashMap
::
new
());
}
}
}
}
pub
fn
remove_worker
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
false
);
}
pub
fn
clear_all_blocks
(
&
mut
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
true
);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
let
mut
worker_ids
:
Vec
<
WorkerId
>
=
self
.lookup
.keys
()
.map
(|
w
|
w
.worker_id
)
.collect
();
worker_ids
.sort_unstable
();
worker_ids
.dedup
();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
tracing
::
debug!
(
"Dumping radix tree as events (contains information about {:?} workers)"
,
self
.lookup
.len
()
);
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
// Queue entries: (current_block, parent_hash, tokens_hash)
let
mut
queue
=
VecDeque
::
new
();
// Process root's children first
let
root_borrow
=
self
.root
.borrow
();
for
(
tokens_hash
,
child_block
)
in
&
root_borrow
.children
{
queue
.push_back
((
child_block
.clone
(),
None
,
*
tokens_hash
));
}
drop
(
root_borrow
);
while
let
Some
((
current_block
,
parent_hash
,
tokens_hash
))
=
queue
.pop_front
()
{
let
current_borrow
=
current_block
.borrow
();
// Get this block's hash (same for all workers)
let
block_hash
=
current_borrow
.block_hash
.expect
(
"non-root block must have block_hash"
);
// For each worker that has this block
for
worker
in
&
current_borrow
.workers
{
// Create a store event for this worker
let
event
=
RouterEvent
{
worker_id
:
worker
.worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
,
mm_extra_info
:
None
,
tokens_hash
,
}],
}),
dp_rank
:
worker
.dp_rank
,
},
};
events
.push
(
event
);
event_id
+=
1
;
}
// Enqueue children with this block's hash as their parent
for
(
child_tokens_hash
,
child_block
)
in
&
current_borrow
.children
{
queue
.push_back
((
child_block
.clone
(),
Some
(
block_hash
),
*
child_tokens_hash
));
}
}
events
}
pub
fn
current_size
(
&
self
)
->
usize
{
self
.lookup
.values
()
.map
(|
m
|
m
.len
())
.sum
()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEvent
,
KvCacheEventData
,
KvCacheRemoveData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
WorkerId
,
};
/// Creates blocks with artificial hash mapping (hash * 100) for testing RadixTree internals.
fn
make_blocks
(
hashes
:
Vec
<
u64
>
)
->
Vec
<
KvCacheStoredBlockData
>
{
hashes
.iter
()
.map
(|
i
|
KvCacheStoredBlockData
{
tokens_hash
:
LocalBlockHash
(
*
i
),
block_hash
:
ExternalSequenceBlockHash
(
*
i
*
100
),
mm_extra_info
:
None
,
})
.collect
()
}
fn
add_blocks
(
hashes
:
Vec
<
u64
>
,
parent_hash
:
Option
<
ExternalSequenceBlockHash
>
,
)
->
KvCacheEventData
{
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
,
blocks
:
make_blocks
(
hashes
),
})
}
fn
create_store_event
(
worker_id
:
WorkerId
,
event_id
:
u64
,
hashes
:
Vec
<
u64
>
,
parent
:
Option
<
ExternalSequenceBlockHash
>
,
)
->
RouterEvent
{
RouterEvent
{
worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
add_blocks
(
hashes
,
parent
),
dp_rank
:
0
,
},
}
}
fn
create_remove_event
(
worker_id
:
WorkerId
,
event_id
:
u64
,
hashes
:
Vec
<
u64
>
)
->
RouterEvent
{
RouterEvent
{
worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
hashes
.iter
()
.map
(|
i
|
ExternalSequenceBlockHash
(
*
i
*
100
))
.collect
(),
}),
dp_rank
:
0
,
},
}
}
#[test]
fn
test_radix_tree
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_1
=
0
;
let
worker_2
=
1
;
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
trie
.lookup
.len
(),
1
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
1
);
trie
.apply_event
(
create_store_event
(
worker_2
,
1
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
1
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_remove_event
(
worker_2
,
2
,
vec!
[
5
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_remove_event
(
worker_2
,
3
,
vec!
[
4
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
trie
.apply_event
(
create_store_event
(
worker_2
,
4
,
vec!
[
2
,
6
,
7
],
Some
(
ExternalSequenceBlockHash
(
100
)),
))
.unwrap
();
let
scores
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
2
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.len
(),
3
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.len
(),
4
);
assert_eq!
(
trie
.root
.borrow
()
.workers
.len
(),
0
);
assert_eq!
(
trie
.root
.borrow
()
.children
.len
(),
1
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
assert_eq!
(
trie
.root
.borrow
()
.children
.get
(
&
LocalBlockHash
(
1
))
.unwrap
()
.borrow
()
.children
.len
(),
2
);
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
200
))
.unwrap
()
.borrow
()
.workers
.len
(),
2
);
}
#[test]
fn
test_radix_tree_apply_event_errors
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
// Parent block not found
let
result
=
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
Some
(
ExternalSequenceBlockHash
(
12345
)),
));
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
ParentBlockNotFound
));
// Block not found for remove event.
let
result
=
trie
.apply_event
(
create_remove_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
]));
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
BlockNotFound
));
// Parent appears in blocks: parent=1, blocks=[1, 2, 3]
// This should be rejected as block 1 (hash 100) is the parent - this is
// a self referencing block.
trie
.apply_event
(
create_store_event
(
worker_0
,
4
,
vec!
[
1
],
None
))
.unwrap
();
let
result
=
trie
.apply_event
(
create_store_event
(
worker_0
,
5
,
vec!
[
1
,
2
,
3
],
Some
(
ExternalSequenceBlockHash
(
100
)),
));
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
InvalidBlockSequence
));
}
#[test]
fn
test_clear_all_blocks
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
assert
!
(
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
.is_empty
()
);
// Test clearing an empty worker
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
// Test clearing a worker with shared blocks
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
0
,
1
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
0
,
2
,
3
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
)],
false
)
.scores
;
assert
!
(
result
.len
()
==
2
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)]
==
1
&&
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)]
==
1
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.is_empty
()
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
2
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
2
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
0
),
LocalBlockHash
(
1
),
LocalBlockHash
(
3
)],
false
,
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
// Test re-adding blocks after clearing worker
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
4
,
5
],
None
))
.unwrap
();
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
4
),
LocalBlockHash
(
5
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)],
2
);
// Test multiple clears
trie
.clear_all_blocks
(
worker_0
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
// Test clearing all workers
trie
.clear_all_blocks
(
worker_0
);
trie
.clear_all_blocks
(
worker_1
);
assert
!
(
!
trie
.lookup
.is_empty
());
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.is_empty
()
);
assert
!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.is_empty
()
);
// Test clearing a worker that has been removed
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
6
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
6
],
None
))
.unwrap
();
trie
.remove_worker
(
worker_0
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
6
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
// Test clearing a worker that doesn't exist
let
worker_fake
=
2
;
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_fake
))
);
trie
.clear_all_blocks
(
worker_fake
);
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_fake
))
);
assert
!
(
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
6
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
1
);
}
#[test]
fn
test_radix_tree_default
()
{
let
radix_tree
:
RadixTree
=
Default
::
default
();
assert
!
(
radix_tree
.root
.borrow
()
.children
.is_empty
());
assert
!
(
radix_tree
.root
.borrow
()
.workers
.is_empty
());
assert
!
(
radix_tree
.lookup
.is_empty
());
}
#[test]
fn
test_remove_worker_verifies_hash_removal
()
{
let
mut
trie
=
RadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
let
worker_2
=
2
;
// Add blocks for multiple workers
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_2
,
0
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
// Verify worker_0 has 3 blocks in lookup
assert_eq!
(
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.len
(),
3
);
// Verify that blocks have the correct workers
let
block_1
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
100
))
.unwrap
();
assert_eq!
(
block_1
.borrow
()
.workers
.len
(),
3
);
// worker_0, worker_1, and worker_2 (all have hash 1)
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
);
// Remove worker_0
trie
.remove_worker
(
worker_0
);
// Verify worker_0 is completely removed from lookup table
assert
!
(
!
trie
.lookup
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert_eq!
(
trie
.lookup
.len
(),
2
);
// Verify that worker_0's hash is removed from the workers set
let
block_1
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
100
))
.unwrap
();
assert_eq!
(
block_1
.borrow
()
.workers
.len
(),
2
);
// worker_1 and worker_2 remain
assert
!
(
!
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
assert
!
(
block_1
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
);
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let
block_2
=
trie
.lookup
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.get
(
&
ExternalSequenceBlockHash
(
200
))
.unwrap
();
assert_eq!
(
block_2
.borrow
()
.workers
.len
(),
1
);
// only worker_1
assert
!
(
block_2
.borrow
()
.workers
.contains
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
);
// Verify match results no longer include worker_0
let
result
=
trie
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
)
.scores
;
assert_eq!
(
result
.len
(),
2
);
assert
!
(
!
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
)));
}
}
lib/llm/src/kv_router.rs
View file @
fc92fc18
...
...
@@ -44,11 +44,11 @@ use crate::{
discovery
::
RuntimeConfigsWithNotify
,
kv_router
::{
approx
::
PruneConfig
,
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
,
OverlapScores
,
RouterEvent
},
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
},
protocols
::{
LocalBlockHash
,
RouterRequest
,
RouterResponse
,
TokensWithHashes
,
WorkerId
,
WorkerSelectionResult
,
WorkerWithDpRank
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
RouterRequest
,
RouterResponse
,
TokensWithHashes
,
WorkerId
,
WorkerSelectionResult
,
WorkerWithDpRank
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
},
scheduler
::{
KvScheduler
,
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
},
sequence
::
SequenceError
,
...
...
lib/llm/src/kv_router/publisher.rs
View file @
fc92fc18
...
...
@@ -42,7 +42,7 @@ fn create_kv_stream_name(component: &Component, subject: &str) -> String {
use
crate
::
kv_router
::{
KV_EVENT_SUBJECT
,
KV_METRICS_SUBJECT
,
WORKER_KV_INDEXER_BUFFER_SIZE
,
indexer
::{
KvIndexerMetrics
,
LocalKvIndexer
,
RouterEvent
},
indexer
::{
KvIndexerMetrics
,
LocalKvIndexer
},
protocols
::
*
,
worker_query
::
start_worker_kv_query_endpoint
,
};
...
...
lib/llm/src/kv_router/recorder.rs
View file @
fc92fc18
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
crate
::
kv_router
::
indexer
::
RouterEvent
;
use
crate
::
kv_router
::
protocols
::
RouterEvent
;
use
crate
::
recorder
::
Recorder
;
// Type alias for backward compatibility
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
fc92fc18
...
...
@@ -17,8 +17,7 @@ use super::KV_HIT_RATE_SUBJECT;
use
super
::
KvRouterConfig
;
use
super
::
RouterConfigOverride
;
use
super
::
WorkerSelector
;
use
super
::
indexer
::
OverlapScores
;
use
super
::
protocols
::{
DpRank
,
WorkerId
,
WorkerSelectionResult
,
WorkerWithDpRank
};
use
super
::
protocols
::{
DpRank
,
OverlapScores
,
WorkerId
,
WorkerSelectionResult
,
WorkerWithDpRank
};
use
super
::
sequence
::{
ActiveSequencesMultiWorker
,
SequenceError
};
use
dynamo_tokens
::
SequenceHash
;
...
...
lib/llm/src/kv_router/sequence.rs
View file @
fc92fc18
...
...
@@ -22,7 +22,7 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
protocols
::
OverlapScores
;
use
anyhow
::
Result
;
use
dashmap
::
DashMap
;
use
derive_getters
::
Getters
;
...
...
lib/llm/src/kv_router/subscriber.rs
View file @
fc92fc18
...
...
@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken;
use
crate
::
kv_router
::{
KV_EVENT_SUBJECT
,
RADIX_STATE_BUCKET
,
RADIX_STATE_FILE
,
indexer
::{
DumpRequest
,
GetWorkersRequest
,
RouterEvent
,
WorkerKvQueryResponse
},
protocols
::
WorkerId
,
indexer
::{
DumpRequest
,
GetWorkersRequest
,
WorkerKvQueryResponse
},
protocols
::
{
RouterEvent
,
WorkerId
}
,
router_discovery_query
,
worker_query
::
WorkerQueryClient
,
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment