Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9e5014da
Unverified
Commit
9e5014da
authored
Feb 28, 2026
by
jthomson04
Committed by
GitHub
Feb 28, 2026
Browse files
perf: Concurrent router perf improvements (#6536)
Signed-off-by:
jthomson04
<
jwillthomson19@gmail.com
>
parent
fd035b19
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
392 additions
and
1879 deletions
+392
-1879
container/templates/dev.Dockerfile
container/templates/dev.Dockerfile
+2
-1
lib/kv-router/Cargo.toml
lib/kv-router/Cargo.toml
+2
-8
lib/kv-router/benches/radix_tree_microbench.rs
lib/kv-router/benches/radix_tree_microbench.rs
+0
-964
lib/kv-router/src/concurrent_radix_tree.rs
lib/kv-router/src/concurrent_radix_tree.rs
+125
-710
lib/kv-router/src/indexer.rs
lib/kv-router/src/indexer.rs
+92
-25
lib/kv-router/src/nested_map.rs
lib/kv-router/src/nested_map.rs
+170
-171
tests/fault_tolerance/deploy/container/Dockerfile.local_vllm
tests/fault_tolerance/deploy/container/Dockerfile.local_vllm
+1
-0
No files found.
container/templates/dev.Dockerfile
View file @
9e5014da
...
...
@@ -107,7 +107,8 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
# Native deps for Python/Rust wheels
patchelf \
clang \
libclang-dev && \
libclang-dev \
libfontconfig-dev && \
rm -rf /var/lib/apt/lists/* && \
# Initialize Git LFS for the dynamo user (required for requirements with lfs=true)
git lfs install
...
...
lib/kv-router/Cargo.toml
View file @
9e5014da
...
...
@@ -12,12 +12,12 @@ repository.workspace = true
[features]
default
=
[]
metrics
=
[
"dep:dynamo-runtime"
]
metrics
=
[]
bench
=
[
"dep:clap"
,
"dep:indicatif"
,
"dep:serde_json"
,
"dynamo-runtime/integration"
,
"dep:plotters"
]
[dependencies]
# repo
dynamo-runtime
=
{
workspace
=
true
,
optional
=
true
}
dynamo-runtime
=
{
workspace
=
true
}
dynamo-tokens
=
{
workspace
=
true
}
# workspace
...
...
@@ -58,12 +58,6 @@ dynamo-mocker = { workspace = true }
dynamo-tokens
=
{
workspace
=
true
}
minstant
=
"0.1.7"
[[bench]]
name
=
"radix_tree_microbench"
harness
=
false
required-features
=
["bench"]
[[bench]]
name
=
"kv_indexer_bench"
harness
=
false
...
...
lib/kv-router/benches/radix_tree_microbench.rs
deleted
100644 → 0
View file @
fd035b19
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Microbenchmark for radix tree operations with configurable size and depth.
//!
//! Measures latency and throughput of:
//! - store_block: Adding blocks to the tree
//! - remove_block: Removing blocks from the tree
//! - find_matches: Finding prefix matches in the tree
//!
//! Size is defined as total (worker, block) pairs in the tree.
//! Depth is the number of blocks per sequence (depth = (isl + osl) / block_size).
//!
//! Run with: cargo bench --package dynamo-kv-router --bench radix_tree_microbench --features bench -- --help
#[path
=
"common/mod.rs"
]
mod
common
;
use
common
::{
SequenceData
,
generate_sequences
};
use
clap
::{
Parser
,
ValueEnum
};
use
dynamo_bench
::
common
::
LatencyStats
;
use
dynamo_kv_router
::{
ConcurrentRadixTree
,
OverlapScores
,
PositionalIndexer
,
RadixTree
,
RouterEvent
,
SyncIndexer
,
compute_block_hash_for_seq
,
protocols
::
LocalBlockHash
,
};
use
std
::
time
::{
Duration
,
Instant
};
/// Unified interface for RadixTree, ConcurrentRadixTree, and PositionalIndexer benchmarking.
///
/// All structures have feature parity for store, remove, find_matches, and current_size.
/// The key difference is find_matches input:
/// - RadixTree/ConcurrentRadixTree: uses LocalBlockHash (tokens_hash)
/// - PositionalIndexer: uses LocalBlockHash (same as tree; internal mapping uses sequence hashes)
enum
KvIndex
{
Tree
(
RadixTree
),
Concurrent
(
ConcurrentRadixTree
),
Nested
(
PositionalIndexer
),
}
impl
KvIndex
{
fn
name
(
&
self
)
->
&
'static
str
{
match
self
{
KvIndex
::
Tree
(
_
)
=>
"RadixTree"
,
KvIndex
::
Concurrent
(
_
)
=>
"ConcurrentRadixTree"
,
KvIndex
::
Nested
(
_
)
=>
"PositionalIndexer"
,
}
}
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
{
let
_
=
tree
.apply_event
(
event
);
}
KvIndex
::
Concurrent
(
tree
)
=>
{
let
_
=
tree
.apply_event
(
event
);
}
KvIndex
::
Nested
(
map
)
=>
{
let
_
=
map
.apply_event
(
event
)
.ok
();
}
}
}
fn
find_matches_timed
(
&
self
,
seq
:
&
SequenceData
,
early_exit
:
bool
)
->
Duration
{
let
local_hashes
=
seq
.local_hashes
.clone
();
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
local_hashes
,
early_exit
),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.find_matches_impl
(
&
local_hashes
,
early_exit
),
KvIndex
::
Nested
(
map
)
=>
map
.find_matches
(
&
local_hashes
,
early_exit
),
};
start
.elapsed
()
}
fn
find_matches_miss_timed
(
&
self
,
depth
:
usize
,
i
:
usize
,
early_exit
:
bool
)
->
Duration
{
let
miss_hashes
:
Vec
<
LocalBlockHash
>
=
(
0
..
depth
)
.map
(|
j
|
LocalBlockHash
(
0xBAD_C0DE_0000_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
)))
.collect
();
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
miss_hashes
,
early_exit
),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.find_matches_impl
(
&
miss_hashes
,
early_exit
),
KvIndex
::
Nested
(
map
)
=>
map
.find_matches
(
&
miss_hashes
,
early_exit
),
};
start
.elapsed
()
}
fn
find_matches_partial_timed
(
&
self
,
seq
:
&
SequenceData
,
half
:
usize
,
i
:
usize
,
early_exit
:
bool
,
)
->
Duration
{
let
mut
partial
=
seq
.local_hashes
[
..
half
]
.to_vec
();
partial
.extend
(
(
0
..
half
)
.map
(|
j
|
LocalBlockHash
(
0xDEAD_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
))),
);
let
start
=
Instant
::
now
();
let
_
=
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
partial
,
early_exit
),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.find_matches_impl
(
&
partial
,
early_exit
),
KvIndex
::
Nested
(
map
)
=>
map
.find_matches
(
&
partial
,
early_exit
),
};
start
.elapsed
()
}
fn
current_size
(
&
self
)
->
usize
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.current_size
(),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.current_size
(),
KvIndex
::
Nested
(
map
)
=>
map
.current_size
(),
}
}
fn
find_matches
(
&
self
,
local_hashes
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
)
->
OverlapScores
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.find_matches
(
local_hashes
,
early_exit
),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.find_matches_impl
(
&
local_hashes
,
early_exit
),
KvIndex
::
Nested
(
map
)
=>
map
.find_matches
(
&
local_hashes
,
early_exit
),
}
}
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
match
self
{
KvIndex
::
Tree
(
tree
)
=>
tree
.dump_tree_as_events
(),
KvIndex
::
Concurrent
(
tree
)
=>
tree
.dump_tree_as_events
(),
KvIndex
::
Nested
(
_
)
=>
{
// NestedMap does not support dump_tree_as_events
vec!
[]
}
}
}
}
/// Sweep benchmark mode
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
ValueEnum)]
enum
SweepMode
{
/// Vary sequence/query length (query has exactly `depth` blocks, all matching)
Depth
,
/// Vary match length (query has `max_depth` blocks, first `depth` match, rest garbage)
MatchLength
,
/// Vary number of prefix prompt groups (width of shared prefixes)
Width
,
}
#[derive(Parser,
Debug)]
#[command(name
=
"radix_tree_microbench"
)]
#[command(about
=
"Microbenchmark for radix tree operations"
)]
struct
Args
{
/// Ignored: passed by cargo bench harness
#[arg(long,
hide
=
true
)]
bench
:
bool
,
/// Target tree size in total (worker, block) pairs
#[arg(long,
default_value
=
"10000"
)]
size
:
usize
,
/// Sequence depth in blocks (depth = (isl + osl) / block_size, where block_size = 16)
#[arg(long,
default_value
=
"32"
)]
depth
:
usize
,
/// Number of workers to distribute blocks across
#[arg(long,
default_value
=
"4"
)]
num_workers
:
usize
,
/// Number of iterations per operation for timing
#[arg(long,
default_value
=
"1000"
)]
iterations
:
usize
,
/// Warmup ratio (0.0 to 1.0) - fraction of iterations to discard for warmup
#[arg(long,
default_value
=
"0.1"
)]
warmup_ratio
:
f64
,
/// Prefix prompt ratio (0.0 to 1.0) - portion of sequence from the beginning that is a shared prefix
#[arg(long,
default_value
=
"0.25"
)]
prefix_prompt_ratio
:
f64
,
/// Number of unique prefix prompt groups to randomly sample from
#[arg(long,
default_value
=
"4"
)]
num_prefix_prompts
:
usize
,
/// Run only specific benchmark (hash, store, remove, find_matches, dump, sweep, or all)
#[arg(long,
default_value
=
"all"
)]
benchmark_type
:
String
,
/// KV block size in tokens (for hash computation)
#[arg(long,
default_value
=
"16"
)]
block_size
:
u32
,
/// Verbose output with per-iteration timings
#[arg(short,
long)]
verbose
:
bool
,
/// Minimum depth for sweep mode
#[arg(long,
default_value
=
"1"
)]
min_depth
:
usize
,
/// Maximum depth for sweep mode
#[arg(long,
default_value
=
"8000"
)]
max_depth
:
usize
,
/// Number of depth points to sample in sweep mode (logarithmically spaced)
#[arg(long,
default_value
=
"20"
)]
sweep_points
:
usize
,
/// Iterations per depth point in sweep mode
#[arg(long,
default_value
=
"100"
)]
sweep_iterations
:
usize
,
/// Output format for sweep mode: "table" or "csv"
#[arg(long,
default_value
=
"table"
)]
sweep_format
:
String
,
/// Sweep mode: what to vary during the sweep
#[arg(long,
value_enum,
default_value
=
"depth"
)]
sweep_mode
:
SweepMode
,
/// Minimum width (num_prefix_prompts) for width sweep mode
#[arg(long,
default_value
=
"1"
)]
min_width
:
usize
,
/// Maximum width (num_prefix_prompts) for width sweep mode
#[arg(long,
default_value
=
"64"
)]
max_width
:
usize
,
/// Random seed for reproducibility
#[arg(long,
default_value
=
"42"
)]
seed
:
u64
,
/// Use nested map instead of radix tree (for comparison)
#[arg(long)]
nested_map
:
bool
,
/// Use concurrent radix tree instead of single-threaded radix tree
#[arg(long)]
concurrent
:
bool
,
}
/// Build a pre-populated KvIndex (prints timing info)
fn
build_index
(
sequences
:
&
[
SequenceData
],
use_nested_map
:
bool
,
use_concurrent
:
bool
)
->
KvIndex
{
let
num_blocks
:
usize
=
sequences
.iter
()
.map
(|
s
|
s
.local_hashes
.len
())
.sum
();
let
name
=
if
use_nested_map
{
"NestedMap"
}
else
if
use_concurrent
{
"ConcurrentRadixTree"
}
else
{
"RadixTree"
};
print!
(
" Building {} with {} sequences ({} blocks)... "
,
name
,
sequences
.len
(),
num_blocks
);
std
::
io
::
Write
::
flush
(
&
mut
std
::
io
::
stdout
())
.unwrap
();
let
start
=
Instant
::
now
();
let
mut
index
=
if
use_nested_map
{
KvIndex
::
Nested
(
PositionalIndexer
::
new
(
32
))
}
else
if
use_concurrent
{
KvIndex
::
Concurrent
(
ConcurrentRadixTree
::
new
())
}
else
{
KvIndex
::
Tree
(
RadixTree
::
new
())
};
for
(
event_id
,
seq
)
in
sequences
.iter
()
.enumerate
()
{
let
event
=
seq
.to_store_event
(
event_id
as
u64
);
index
.apply_event
(
event
);
}
let
elapsed
=
start
.elapsed
();
println!
(
"done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)"
,
elapsed
,
sequences
.len
()
as
f64
/
elapsed
.as_secs_f64
(),
num_blocks
as
f64
/
elapsed
.as_secs_f64
()
);
index
}
/// Benchmark compute_block_hash_for_seq operation
fn
bench_hash
(
args
:
&
Args
)
{
println!
(
"
\n
=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ==="
);
let
num_tokens
=
args
.depth
*
args
.block_size
as
usize
;
println!
(
" Token sequence length: {} tokens ({} blocks)"
,
num_tokens
,
args
.depth
);
// Generate token sequences to hash
let
token_sequences
:
Vec
<
Vec
<
u32
>>
=
(
0
..
args
.iterations
)
.map
(|
i
|
{
(
0
..
num_tokens
)
.map
(|
j
|
((
i
*
num_tokens
+
j
)
%
50000
)
as
u32
)
.collect
()
})
.collect
();
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
mut
durations
=
Vec
::
with_capacity
(
measured_iters
);
for
(
i
,
tokens
)
in
token_sequences
.iter
()
.enumerate
()
{
let
start
=
Instant
::
now
();
let
_
=
compute_block_hash_for_seq
(
tokens
,
args
.block_size
,
None
,
None
);
let
elapsed
=
start
.elapsed
();
if
i
>=
warmup_iters
{
durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
let
stats
=
LatencyStats
::
from_durations
(
&
durations
)
.unwrap
();
stats
.print
(
"COMPUTE_BLOCK_HASH"
,
args
.depth
);
}
/// Benchmark store or remove operation on a steady-state index.
///
/// Uses a remove/store cycle to maintain size. If `time_store` is true,
/// the store operation is timed; otherwise the remove operation is timed.
fn
bench_store_remove_cycle
(
args
:
&
Args
,
time_store
:
bool
)
{
let
op_name
=
if
time_store
{
"STORE_BLOCK"
}
else
{
"REMOVE_BLOCK"
};
let
num_sequences
=
args
.size
/
args
.depth
;
let
sequences
=
generate_sequences
(
num_sequences
,
args
.depth
,
args
.num_workers
,
args
.prefix_prompt_ratio
,
args
.num_prefix_prompts
,
args
.seed
,
true
,
);
let
mut
index
=
build_index
(
&
sequences
,
args
.nested_map
,
args
.concurrent
);
println!
(
"
\n
=== Benchmarking {} ({}) ==="
,
op_name
,
index
.name
());
println!
(
" Size: {} blocks"
,
index
.current_size
());
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
mut
durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
remove_event
=
seq
.to_remove_event
(
i
as
u64
);
let
store_event
=
seq
.to_store_event
(
i
as
u64
+
args
.iterations
as
u64
);
let
elapsed
=
if
time_store
{
index
.apply_event
(
remove_event
);
let
start
=
Instant
::
now
();
index
.apply_event
(
store_event
);
start
.elapsed
()
}
else
{
let
start
=
Instant
::
now
();
index
.apply_event
(
remove_event
);
let
elapsed
=
start
.elapsed
();
index
.apply_event
(
store_event
);
elapsed
};
if
i
>=
warmup_iters
{
durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
let
stats
=
LatencyStats
::
from_durations
(
&
durations
)
.unwrap
();
stats
.print
(
op_name
,
args
.depth
);
}
/// Benchmark store_block operation
fn
bench_store
(
args
:
&
Args
)
{
bench_store_remove_cycle
(
args
,
true
);
}
/// Benchmark remove_block operation
fn
bench_remove
(
args
:
&
Args
)
{
bench_store_remove_cycle
(
args
,
false
);
}
/// Benchmark find_matches operation
fn
bench_find_matches
(
args
:
&
Args
)
{
let
num_sequences
=
args
.size
/
args
.depth
;
let
sequences
=
generate_sequences
(
num_sequences
,
args
.depth
,
args
.num_workers
,
args
.prefix_prompt_ratio
,
args
.num_prefix_prompts
,
args
.seed
,
true
,
);
let
index
=
build_index
(
&
sequences
,
args
.nested_map
,
args
.concurrent
);
println!
(
"
\n
=== Benchmarking FIND_MATCHES ({}) ==="
,
index
.name
());
println!
(
" Built with {} sequences, {} total blocks"
,
sequences
.len
(),
index
.current_size
()
);
let
warmup_iters
=
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
;
let
measured_iters
=
args
.iterations
-
warmup_iters
;
let
half
=
args
.depth
/
2
;
// HIT case
println!
(
"
\n
--- HIT case (existing sequences) ---"
);
let
mut
hit_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
elapsed
=
index
.find_matches_timed
(
seq
,
false
);
if
i
>=
warmup_iters
{
hit_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
&
hit_durations
)
.unwrap
()
.print
(
"FIND_MATCHES (HIT)"
,
args
.depth
);
// MISS case
println!
(
"
\n
--- MISS case (non-existing sequences) ---"
);
let
mut
miss_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
elapsed
=
index
.find_matches_miss_timed
(
args
.depth
,
i
,
false
);
if
i
>=
warmup_iters
{
miss_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
&
miss_durations
)
.unwrap
()
.print
(
"FIND_MATCHES (MISS)"
,
args
.depth
);
// PARTIAL case
println!
(
"
\n
--- PARTIAL case (prefix match only) ---"
);
let
mut
partial_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
elapsed
=
index
.find_matches_partial_timed
(
seq
,
half
,
i
,
false
);
if
i
>=
warmup_iters
{
partial_durations
.push
(
elapsed
);
}
if
args
.verbose
&&
(
i
+
1
)
%
100
==
0
{
println!
(
" Completed {}/{} iterations"
,
i
+
1
,
args
.iterations
);
}
}
LatencyStats
::
from_durations
(
&
partial_durations
)
.unwrap
()
.print
(
"FIND_MATCHES (PARTIAL)"
,
args
.depth
);
// EARLY_EXIT case
println!
(
"
\n
--- EARLY_EXIT case ---"
);
let
mut
early_exit_durations
=
Vec
::
with_capacity
(
measured_iters
);
for
i
in
0
..
args
.iterations
{
let
seq
=
&
sequences
[
i
%
sequences
.len
()];
let
elapsed
=
index
.find_matches_timed
(
seq
,
true
);
if
i
>=
warmup_iters
{
early_exit_durations
.push
(
elapsed
);
}
}
LatencyStats
::
from_durations
(
&
early_exit_durations
)
.unwrap
()
.print
(
"FIND_MATCHES (EARLY_EXIT)"
,
args
.depth
);
}
/// Generate logarithmically spaced values between min and max
fn
generate_log_spaced_points
(
min_val
:
usize
,
max_val
:
usize
,
num_points
:
usize
)
->
Vec
<
usize
>
{
if
num_points
<=
1
{
return
vec!
[
max_val
];
}
let
log_min
=
(
min_val
as
f64
)
.ln
();
let
log_max
=
(
max_val
as
f64
)
.ln
();
let
step
=
(
log_max
-
log_min
)
/
(
num_points
-
1
)
as
f64
;
let
mut
points
:
Vec
<
usize
>
=
(
0
..
num_points
)
.map
(|
i
|
(
log_min
+
step
*
i
as
f64
)
.exp
()
.round
()
as
usize
)
.map
(|
v
|
v
.max
(
1
))
// Ensure minimum value of 1
.collect
();
// Deduplicate (logarithmic spacing can produce duplicates at low values)
points
.dedup
();
points
}
/// Latency statistics (avg, p50, p99) in nanoseconds
#[derive(Debug)]
struct
DurationStats
{
avg_ns
:
u64
,
p50_ns
:
u64
,
p99_ns
:
u64
,
}
impl
DurationStats
{
/// Compute stats from durations. Sorts the input vector in place.
fn
from_durations
(
durations
:
&
mut
[
Duration
])
->
Self
{
durations
.sort
();
let
n
=
durations
.len
();
let
avg
=
durations
.iter
()
.sum
::
<
Duration
>
()
/
n
as
u32
;
Self
{
avg_ns
:
avg
.as_nanos
()
as
u64
,
p50_ns
:
durations
[
n
/
2
]
.as_nanos
()
as
u64
,
p99_ns
:
durations
[
n
*
99
/
100
]
.as_nanos
()
as
u64
,
}
}
}
/// Results for a single sweep point (depth or width)
#[derive(Debug)]
struct
SweepResult
{
point
:
usize
,
point_label
:
&
'static
str
,
store
:
DurationStats
,
remove
:
DurationStats
,
find_matches
:
DurationStats
,
}
impl
SweepResult
{
fn
csv_header
(
&
self
)
->
String
{
format!
(
"{},store_avg_ns,store_p50_ns,store_p99_ns,remove_avg_ns,remove_p50_ns,remove_p99_ns,find_matches_avg_ns,find_matches_p50_ns,find_matches_p99_ns"
,
self
.point_label
)
}
fn
csv_row
(
&
self
)
->
String
{
format!
(
"{},{},{},{},{},{},{},{},{},{}"
,
self
.point
,
self
.store.avg_ns
,
self
.store.p50_ns
,
self
.store.p99_ns
,
self
.remove.avg_ns
,
self
.remove.p50_ns
,
self
.remove.p99_ns
,
self
.find_matches.avg_ns
,
self
.find_matches.p50_ns
,
self
.find_matches.p99_ns
)
}
fn
table_header
(
&
self
)
->
String
{
format!
(
"{:>8} | store_avg store_p50 store_p99 | remove_avg remove_p50 remove_p99 | fm_avg fm_p50 fm_p99"
,
self
.point_label
)
}
fn
table_row
(
&
self
)
->
String
{
format!
(
"{:>8} | {:>12} {:>12} {:>12} | {:>12} {:>12} {:>12} | {:>12} {:>12} {:>12}"
,
self
.point
,
format_duration_ns
(
self
.store.avg_ns
),
format_duration_ns
(
self
.store.p50_ns
),
format_duration_ns
(
self
.store.p99_ns
),
format_duration_ns
(
self
.remove.avg_ns
),
format_duration_ns
(
self
.remove.p50_ns
),
format_duration_ns
(
self
.remove.p99_ns
),
format_duration_ns
(
self
.find_matches.avg_ns
),
format_duration_ns
(
self
.find_matches.p50_ns
),
format_duration_ns
(
self
.find_matches.p99_ns
)
)
}
}
fn
print_sweep_results_dynamic
(
results
:
&
[
SweepResult
],
format
:
&
str
)
{
if
results
.is_empty
()
{
return
;
}
println!
();
if
format
==
"csv"
{
println!
(
"{}"
,
results
[
0
]
.csv_header
());
for
r
in
results
{
println!
(
"{}"
,
r
.csv_row
());
}
}
else
{
println!
(
"{}"
,
results
[
0
]
.table_header
());
println!
(
"{}"
,
"-"
.repeat
(
130
));
for
r
in
results
{
println!
(
"{}"
,
r
.table_row
());
}
}
}
/// Benchmark store/remove/find_matches across a range of depths or widths.
///
/// For each sweep point, the tree is rebuilt.
///
/// With `--sweep_mode match_length`, find_matches queries have `max_depth` blocks
/// where only the first `depth` blocks match (rest are garbage). With `--sweep_mode depth`,
/// queries have exactly `depth` blocks (all matching). With `--sweep_mode width`,
/// the number of prefix prompt groups is varied.
fn
bench_sweep
(
args
:
&
Args
)
{
let
seq_length
=
args
.max_depth
;
let
num_sequences
=
args
.size
/
seq_length
;
if
num_sequences
<
2
{
eprintln!
(
"Error: size {} / max_depth {} = {} sequences (need at least 2).
\
Increase --size or decrease --max-depth."
,
args
.size
,
seq_length
,
num_sequences
);
std
::
process
::
exit
(
1
);
}
let
(
mode_name
,
point_label
,
sweep_points
)
=
match
args
.sweep_mode
{
SweepMode
::
Depth
=>
(
"Depth"
,
"depth"
,
generate_log_spaced_points
(
args
.min_depth
,
args
.max_depth
,
args
.sweep_points
),
),
SweepMode
::
MatchLength
=>
(
"Match Length"
,
"depth"
,
generate_log_spaced_points
(
args
.min_depth
,
args
.max_depth
,
args
.sweep_points
),
),
SweepMode
::
Width
=>
(
"Width"
,
"width"
,
generate_log_spaced_points
(
args
.min_width
,
args
.max_width
,
args
.sweep_points
),
),
};
println!
(
"
\n
=== {} Sweep Benchmark ==="
,
mode_name
);
println!
(
" Sequence length: {} blocks (fixed)"
,
seq_length
);
match
args
.sweep_mode
{
SweepMode
::
Depth
|
SweepMode
::
MatchLength
=>
{
println!
(
" Sweep range: {} to {} ({} points, log-spaced)"
,
args
.min_depth
,
args
.max_depth
,
args
.sweep_points
);
}
SweepMode
::
Width
=>
{
println!
(
" Width range: {} to {} ({} points, log-spaced)"
,
args
.min_width
,
args
.max_width
,
args
.sweep_points
);
println!
(
" Prefix prompt ratio: {:.1}%"
,
args
.prefix_prompt_ratio
*
100.0
);
}
}
println!
(
" Iterations per point: {}"
,
args
.sweep_iterations
);
println!
(
" Tree: {} sequences, {} total blocks"
,
num_sequences
,
num_sequences
*
seq_length
);
println!
(
" Workers: {}"
,
args
.num_workers
);
match
args
.sweep_mode
{
SweepMode
::
MatchLength
=>
{
println!
(
" Mode: find_matches queries padded with garbage to max_depth"
);
}
SweepMode
::
Depth
=>
{
println!
(
" Mode: find_matches queries truncated to depth"
);
}
SweepMode
::
Width
=>
{
println!
(
" Mode: varying num_prefix_prompts, full-depth operations"
);
}
}
println!
();
let
mut
results
:
Vec
<
SweepResult
>
=
Vec
::
with_capacity
(
sweep_points
.len
());
for
(
idx
,
&
point
)
in
sweep_points
.iter
()
.enumerate
()
{
print!
(
"[{}/{}] {}={}... "
,
idx
+
1
,
sweep_points
.len
(),
point_label
,
point
);
std
::
io
::
Write
::
flush
(
&
mut
std
::
io
::
stdout
())
.unwrap
();
// Determine depth and num_prefix_prompts for this sweep point
let
(
depth
,
num_prefix_prompts
)
=
match
args
.sweep_mode
{
SweepMode
::
Depth
|
SweepMode
::
MatchLength
=>
(
point
,
args
.num_prefix_prompts
),
SweepMode
::
Width
=>
(
seq_length
,
point
),
};
// Generate sequences and rebuild tree for this point
let
extra_count
=
args
.sweep_iterations
;
let
all_sequences
=
generate_sequences
(
num_sequences
+
extra_count
,
seq_length
,
args
.num_workers
,
args
.prefix_prompt_ratio
,
num_prefix_prompts
,
args
.seed
,
true
,
);
let
tree_sequences
=
&
all_sequences
[
..
num_sequences
];
let
extra_sequences
=
&
all_sequences
[
num_sequences
..
];
let
mut
index
=
build_index
(
tree_sequences
,
args
.nested_map
,
args
.concurrent
);
// --- STORE benchmark ---
let
mut
store_durations
=
Vec
::
with_capacity
(
args
.sweep_iterations
);
for
(
i
,
seq
)
in
extra_sequences
.iter
()
.enumerate
()
.take
(
args
.sweep_iterations
)
{
let
truncated
=
SequenceData
{
worker_id
:
seq
.worker_id
,
local_hashes
:
seq
.local_hashes
[
..
depth
]
.to_vec
(),
external_hashes
:
seq
.external_hashes
[
..
depth
]
.to_vec
(),
};
let
store_event
=
truncated
.to_store_event
(
i
as
u64
);
let
start
=
Instant
::
now
();
index
.apply_event
(
store_event
);
store_durations
.push
(
start
.elapsed
());
// Remove to restore index state (untimed)
let
remove_event
=
truncated
.to_remove_event
(
i
as
u64
);
index
.apply_event
(
remove_event
);
}
// --- REMOVE benchmark ---
let
mut
remove_durations
=
Vec
::
with_capacity
(
args
.sweep_iterations
);
for
i
in
0
..
args
.sweep_iterations
.min
(
num_sequences
)
{
let
seq
=
&
tree_sequences
[
i
%
tree_sequences
.len
()];
let
truncated
=
SequenceData
{
worker_id
:
seq
.worker_id
,
local_hashes
:
seq
.local_hashes
[
..
depth
]
.to_vec
(),
external_hashes
:
seq
.external_hashes
[
..
depth
]
.to_vec
(),
};
let
remove_event
=
truncated
.to_remove_event
(
i
as
u64
);
let
start
=
Instant
::
now
();
index
.apply_event
(
remove_event
);
remove_durations
.push
(
start
.elapsed
());
// Re-add to restore state (untimed)
let
store_event
=
truncated
.to_store_event
(
i
as
u64
+
1000000
);
index
.apply_event
(
store_event
);
}
// --- FIND_MATCHES benchmark ---
let
mut
find_matches_durations
=
Vec
::
with_capacity
(
args
.sweep_iterations
);
for
i
in
0
..
args
.sweep_iterations
{
let
seq
=
&
tree_sequences
[
i
%
tree_sequences
.len
()];
let
query
=
match
args
.sweep_mode
{
SweepMode
::
MatchLength
=>
{
// Match length mode: first `depth` blocks match, rest are garbage
let
mut
q
=
seq
.local_hashes
[
..
depth
]
.to_vec
();
let
garbage_len
=
seq_length
-
depth
;
q
.extend
((
0
..
garbage_len
)
.map
(|
j
|
{
LocalBlockHash
(
0xBAD_C0DE_0000_0000
|
((
i
as
u64
)
<<
16
)
|
(
j
as
u64
))
}));
q
}
SweepMode
::
Depth
|
SweepMode
::
Width
=>
{
// Depth/width mode: query has exactly `depth` blocks
seq
.local_hashes
[
..
depth
]
.to_vec
()
}
};
let
start
=
Instant
::
now
();
let
_
=
index
.find_matches
(
query
,
false
);
find_matches_durations
.push
(
start
.elapsed
());
}
// Compute stats
let
store
=
DurationStats
::
from_durations
(
&
mut
store_durations
);
let
remove
=
DurationStats
::
from_durations
(
&
mut
remove_durations
);
let
find_matches
=
DurationStats
::
from_durations
(
&
mut
find_matches_durations
);
println!
(
"store={:.2}us, remove={:.2}us, find_matches={:.2}us"
,
store
.avg_ns
as
f64
/
1000.0
,
remove
.avg_ns
as
f64
/
1000.0
,
find_matches
.avg_ns
as
f64
/
1000.0
);
results
.push
(
SweepResult
{
point
,
point_label
,
store
,
remove
,
find_matches
,
});
}
print_sweep_results_dynamic
(
&
results
,
&
args
.sweep_format
);
}
/// Benchmark dump_tree_as_events (BFS dump)
fn
bench_dump
(
args
:
&
Args
)
{
println!
(
"
\n
=== Benchmarking DUMP_TREE_AS_EVENTS (BFS dump) ==="
);
let
num_sequences
=
args
.size
/
args
.depth
;
let
sequences
=
generate_sequences
(
num_sequences
,
args
.depth
,
args
.num_workers
,
args
.prefix_prompt_ratio
,
args
.num_prefix_prompts
,
args
.seed
,
true
,
);
let
index
=
build_index
(
&
sequences
,
args
.nested_map
,
args
.concurrent
);
println!
(
" {} built with {} sequences, {} total blocks"
,
index
.name
(),
sequences
.len
(),
index
.current_size
()
);
// Single iteration timing
let
start
=
Instant
::
now
();
let
events
=
index
.dump_tree_as_events
();
let
elapsed
=
start
.elapsed
();
println!
(
"
\n
DUMP_TREE_AS_EVENTS Results:"
);
println!
(
" Time: {:?}"
,
elapsed
);
println!
(
" Events: {}"
,
events
.len
());
println!
(
" Throughput: {:.2} events/sec"
,
events
.len
()
as
f64
/
elapsed
.as_secs_f64
()
);
}
/// Format nanoseconds as human-readable string
fn
format_duration_ns
(
ns
:
u64
)
->
String
{
if
ns
>=
1_000_000_000
{
format!
(
"{:.2}s"
,
ns
as
f64
/
1_000_000_000.0
)
}
else
if
ns
>=
1_000_000
{
format!
(
"{:.2}ms"
,
ns
as
f64
/
1_000_000.0
)
}
else
if
ns
>=
1_000
{
format!
(
"{:.2}us"
,
ns
as
f64
/
1_000.0
)
}
else
{
format!
(
"{}ns"
,
ns
)
}
}
fn
main
()
{
let
args
=
Args
::
parse
();
// Validate arguments to prevent panics
if
args
.size
==
0
||
args
.depth
==
0
||
args
.num_workers
==
0
||
args
.iterations
==
0
||
args
.block_size
==
0
||
args
.min_depth
==
0
||
args
.max_depth
==
0
||
args
.min_width
==
0
||
args
.max_width
==
0
||
args
.sweep_iterations
==
0
{
eprintln!
(
"size, depth, num_workers, iterations, block_size, min_depth, max_depth, min_width, max_width, and sweep_iterations must be > 0"
);
std
::
process
::
exit
(
1
);
}
if
args
.min_depth
>
args
.max_depth
{
eprintln!
(
"min_depth must be <= max_depth"
);
std
::
process
::
exit
(
1
);
}
if
args
.min_width
>
args
.max_width
{
eprintln!
(
"min_width must be <= max_width"
);
std
::
process
::
exit
(
1
);
}
if
!
(
0.0
..=
1.0
)
.contains
(
&
args
.prefix_prompt_ratio
)
{
eprintln!
(
"prefix_prompt_ratio must be between 0.0 and 1.0"
);
std
::
process
::
exit
(
1
);
}
if
!
(
0.0
..=
1.0
)
.contains
(
&
args
.warmup_ratio
)
{
eprintln!
(
"warmup_ratio must be between 0.0 and 1.0"
);
std
::
process
::
exit
(
1
);
}
let
num_sequences
=
args
.size
/
args
.depth
;
if
matches!
(
args
.benchmark_type
.as_str
(),
"store"
|
"remove"
|
"lookup"
|
"sweep"
|
"all"
)
&&
num_sequences
==
0
{
eprintln!
(
"size must be >= depth to produce at least one sequence for {}"
,
args
.benchmark_type
);
std
::
process
::
exit
(
1
);
}
println!
(
"Radix Tree Microbenchmark"
);
println!
(
"=========================
\n
"
);
println!
(
"Configuration:"
);
println!
(
" Target size: {} (worker, block) pairs"
,
args
.size
);
println!
(
" Depth: {} blocks/sequence (= {} tokens with block_size={})"
,
args
.depth
,
args
.depth
*
args
.block_size
as
usize
,
args
.block_size
);
println!
(
" Block size: {} tokens"
,
args
.block_size
);
println!
(
" Workers: {}"
,
args
.num_workers
);
println!
(
" Iterations: {}"
,
args
.iterations
);
println!
(
" Warmup: {:.0}% ({} iterations discarded)"
,
args
.warmup_ratio
*
100.0
,
(
args
.iterations
as
f64
*
args
.warmup_ratio
)
as
usize
);
println!
(
" Prefix prompt ratio: {:.1}% ({} blocks at depth {})"
,
args
.prefix_prompt_ratio
*
100.0
,
(
args
.depth
as
f64
*
args
.prefix_prompt_ratio
)
.round
()
as
usize
,
args
.depth
);
println!
(
" Prefix prompt groups: {}"
,
args
.num_prefix_prompts
);
println!
(
"
\n
Derived: {} sequences to reach target size"
,
num_sequences
);
match
args
.benchmark_type
.as_str
()
{
"hash"
=>
bench_hash
(
&
args
),
"store"
=>
bench_store
(
&
args
),
"remove"
=>
bench_remove
(
&
args
),
"find_matches"
=>
bench_find_matches
(
&
args
),
"dump"
=>
bench_dump
(
&
args
),
"sweep"
=>
bench_sweep
(
&
args
),
"all"
=>
{
bench_hash
(
&
args
);
bench_store
(
&
args
);
bench_remove
(
&
args
);
bench_find_matches
(
&
args
);
bench_dump
(
&
args
);
}
_
=>
{
eprintln!
(
"Unknown benchmark type: {}. Use 'hash', 'store', 'remove', 'find_matches', 'dump', 'sweep', or 'all'"
,
args
.benchmark_type
);
std
::
process
::
exit
(
1
);
}
}
println!
(
"
\n
Benchmark complete."
);
}
lib/kv-router/src/concurrent_radix_tree.rs
View file @
9e5014da
...
...
@@ -25,12 +25,15 @@
//! per-worker write concurrency.
//! - Deadlock prevention: always lock parent before child, hand-over-hand locking
use
std
::
{
collections
::
VecDeque
,
sync
::
Arc
}
;
use
std
::
sync
::
Arc
;
use
dashmap
::
DashMap
;
use
parking_lot
::
RwLock
;
use
rustc_hash
::{
FxHashMap
,
FxHashSet
};
use
rustc_hash
::{
FxBuildHasher
,
FxHashMap
,
FxHashSet
};
use
std
::
collections
::
VecDeque
;
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
};
use
crate
::
indexer
::
SyncIndexer
;
use
crate
::
indexer
::
{
SyncIndexer
,
WorkerTask
}
;
use
crate
::
protocols
::
*
;
/// Thread-safe shared reference to a Block.
...
...
@@ -98,10 +101,7 @@ pub struct ConcurrentRadixTree {
/// This will only contain root blocks.
root
:
SharedBlock
,
/// Per-worker lookup table for O(1) block access.
/// Outer RwLock protects the worker map structure (rarely mutated);
/// inner RwLock per worker protects that worker's block-hash map.
lookup
:
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
RwLock
<
WorkerLookup
>>>
,
tree_sizes
:
DashMap
<
WorkerWithDpRank
,
AtomicUsize
,
FxBuildHasher
>
,
}
impl
Default
for
ConcurrentRadixTree
{
...
...
@@ -122,14 +122,9 @@ impl Drop for ConcurrentRadixTree {
stack
.extend
(
root
.children
.drain
()
.map
(|(
_
,
v
)|
v
));
}
// Remove all lookup references (they may include blocks not reachable from root).
// We have &mut self so no concurrent access; drain the map.
let
lookup
=
self
.lookup
.get_mut
();
for
(
_
,
inner_lock
)
in
lookup
.drain
()
{
stack
.extend
(
inner_lock
.into_inner
()
.into_values
());
}
// Iteratively free any uniquely-owned blocks without recursion
// Iteratively drop blocks to avoid stack overflow on deep trees.
// Without this loop, dropping `stack` would recursively drop each
// Arc<RwLock<Block>> through its `children` map.
while
let
Some
(
block
)
=
stack
.pop
()
{
if
let
Ok
(
rwlock
)
=
Arc
::
try_unwrap
(
block
)
{
let
mut
inner
=
rwlock
.into_inner
();
...
...
@@ -144,7 +139,7 @@ impl ConcurrentRadixTree {
pub
fn
new
()
->
Self
{
Self
{
root
:
Arc
::
new
(
RwLock
::
new
(
Block
::
new
())),
lookup
:
RwLock
::
new
(
FxHashMap
::
default
()
),
tree_sizes
:
DashMap
::
with_hasher
(
FxBuildHasher
),
}
}
...
...
@@ -197,10 +192,11 @@ impl ConcurrentRadixTree {
for
worker
in
&
active
{
scores
.scores
.insert
(
*
worker
,
1
);
}
let
lk
=
self
.lookup
.read
();
for
worker
in
scores
.scores
.keys
()
{
if
let
Some
(
inner_lock
)
=
lk
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
inner_lock
.read
()
.len
());
if
let
Some
(
worker_tree_size
)
=
self
.tree_sizes
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
worker_tree_size
.load
(
Ordering
::
Relaxed
));
}
}
return
scores
;
...
...
@@ -272,10 +268,11 @@ impl ConcurrentRadixTree {
}
// Get tree sizes from lookup.
let
lk
=
self
.lookup
.read
();
for
worker
in
scores
.scores
.keys
()
{
if
let
Some
(
inner_lock
)
=
lk
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
inner_lock
.read
()
.len
());
if
let
Some
(
worker_tree_size
)
=
self
.tree_sizes
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
worker_tree_size
.load
(
Ordering
::
Relaxed
));
}
}
...
...
@@ -290,7 +287,11 @@ impl ConcurrentRadixTree {
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
fn
apply_event
(
&
self
,
lookup
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
WorkerLookup
>
,
event
:
RouterEvent
,
)
->
Result
<
(),
KvCacheEventError
>
{
let
(
worker_id
,
kv_event
)
=
(
event
.worker_id
,
event
.event
);
let
(
id
,
op
)
=
(
kv_event
.event_id
,
kv_event
.data
);
...
...
@@ -298,10 +299,17 @@ impl ConcurrentRadixTree {
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
kv_event
.dp_rank
);
match
op
{
KvCacheEventData
::
Stored
(
op
)
=>
self
.apply_stored
(
worker
,
op
,
id
),
KvCacheEventData
::
Removed
(
op
)
=>
self
.apply_removed
(
worker
,
op
,
id
),
KvCacheEventData
::
Stored
(
op
)
=>
self
.apply_stored
(
lookup
,
worker
,
op
,
id
),
KvCacheEventData
::
Removed
(
op
)
=>
self
.apply_removed
(
lookup
,
worker
,
op
,
id
),
KvCacheEventData
::
Cleared
=>
{
self
.clear_all_blocks
(
worker
.worker_id
);
// Ensure the worker is tracked in lookup before clearing,
// matching RadixTree behavior where `lookup.entry(worker).or_default()`
// fires before the match arm.
lookup
.entry
(
worker
)
.or_default
();
self
.tree_sizes
.entry
(
worker
)
.or_insert_with
(||
AtomicUsize
::
new
(
0
));
self
.clear_all_blocks
(
lookup
,
worker
.worker_id
);
Ok
(())
}
}
...
...
@@ -310,20 +318,13 @@ impl ConcurrentRadixTree {
/// Apply a store operation.
fn
apply_stored
(
&
self
,
lookup
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
WorkerLookup
>
,
worker
:
WorkerWithDpRank
,
op
:
KvCacheStoreData
,
id
:
u64
,
)
->
Result
<
(),
KvCacheEventError
>
{
// Ensure this worker has an entry in the outer map.
if
!
self
.lookup
.read
()
.contains_key
(
&
worker
)
{
self
.lookup
.write
()
.entry
(
worker
)
.or_insert_with
(||
RwLock
::
new
(
FxHashMap
::
default
()));
}
let
lk
=
self
.lookup
.read
();
let
mut
worker_lookup
=
lk
.get
(
&
worker
)
.unwrap
()
.write
();
let
worker_lookup
=
lookup
.entry
(
worker
)
.or_default
();
// Find parent block
let
mut
current
=
match
op
.parent_hash
{
...
...
@@ -346,6 +347,8 @@ impl ConcurrentRadixTree {
let
mut
needs_worker_insert
=
false
;
let
num_blocks_added
=
op
.blocks
.len
();
// In each iteration, we lock the parent block and insert the worker into it from
// the previous iteration. This avoids locking a block twice.
for
block_data
in
op
.blocks
{
...
...
@@ -399,6 +402,16 @@ impl ConcurrentRadixTree {
current
=
child
;
}
match
self
.tree_sizes
.get
(
&
worker
)
{
Some
(
size
)
=>
{
size
.fetch_add
(
num_blocks_added
,
Ordering
::
Relaxed
);
}
None
=>
{
self
.tree_sizes
.insert
(
worker
,
AtomicUsize
::
new
(
num_blocks_added
));
}
}
// Insert worker into the last child (not yet handled since there is
// no subsequent iteration to pick it up).
if
needs_worker_insert
{
...
...
@@ -417,15 +430,16 @@ impl ConcurrentRadixTree {
/// `child_count > active_count`.
fn
apply_removed
(
&
self
,
lookup
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
WorkerLookup
>
,
worker
:
WorkerWithDpRank
,
op
:
KvCacheRemoveData
,
id
:
u64
,
)
->
Result
<
(),
KvCacheEventError
>
{
let
lk
=
self
.lookup
.read
();
let
Some
(
inner_ref
)
=
lk
.get
(
&
worker
)
else
{
let
Some
(
worker_lookup
)
=
lookup
.get_mut
(
&
worker
)
else
{
return
Err
(
KvCacheEventError
::
BlockNotFound
);
};
let
mut
worker_lookup
=
inner_ref
.write
();
let
mut
num_removed
=
0
;
for
block_hash
in
op
.block_hashes
{
let
Some
(
block
)
=
worker_lookup
.remove
(
&
block_hash
)
else
{
...
...
@@ -445,6 +459,18 @@ impl ConcurrentRadixTree {
if
guard
.workers
.is_empty
()
{
guard
.children
.clear
();
}
num_removed
+=
1
;
}
match
self
.tree_sizes
.get
(
&
worker
)
{
Some
(
size
)
=>
{
size
.fetch_sub
(
num_removed
,
Ordering
::
Relaxed
);
}
None
=>
{
self
.tree_sizes
.insert
(
worker
,
AtomicUsize
::
new
(
num_removed
));
}
}
Ok
(())
...
...
@@ -453,20 +479,21 @@ impl ConcurrentRadixTree {
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn
remove_or_clear_worker_blocks
(
&
self
,
worker_id
:
WorkerId
,
keep_worker
:
bool
)
{
let
workers
:
Vec
<
WorkerWithDpRank
>
=
self
.lookup
.read
()
fn
remove_or_clear_worker_blocks
(
&
self
,
lookup
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
WorkerLookup
>
,
worker_id
:
WorkerId
,
keep_worker
:
bool
,
)
{
let
workers
:
Vec
<
WorkerWithDpRank
>
=
lookup
.keys
()
.filter
(|
w
|
w
.worker_id
==
worker_id
)
.copied
()
.collect
();
let
mut
lk
=
self
.lookup
.write
();
for
worker
in
workers
{
if
let
Some
(
inner_lock
)
=
lk
.remove
(
&
worker
)
{
let
blocks
=
inner_lock
.into_inner
();
for
(
_
,
block
)
in
blocks
{
if
let
Some
(
worker_lookup
)
=
lookup
.remove
(
&
worker
)
{
for
(
_
,
block
)
in
worker_lookup
.into_iter
()
{
let
mut
guard
=
block
.write
();
guard
.workers
.remove
(
&
worker
);
if
guard
.workers
.is_empty
()
{
...
...
@@ -475,45 +502,49 @@ impl ConcurrentRadixTree {
}
if
keep_worker
{
lk
.insert
(
worker
,
RwLock
::
new
(
FxHashMap
::
default
()));
lookup
.insert
(
worker
,
FxHashMap
::
default
());
// Reset tree size to 0 but keep the entry so get_workers()
// still returns this worker (matches RadixTree::clear_all_blocks behavior).
if
let
Some
(
size
)
=
self
.tree_sizes
.get
(
&
worker
)
{
size
.store
(
0
,
Ordering
::
Relaxed
);
}
}
else
{
// Fully remove the worker from tree_sizes so get_workers()
// no longer returns it (matches RadixTree::remove_worker behavior).
self
.tree_sizes
.remove
(
&
worker
);
}
}
}
/// Remove a worker and all their blocks from the tree.
pub
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
false
);
}
/// Clear all blocks for a worker but keep the worker tracked.
pub
fn
clear_all_blocks
(
&
self
,
worker_id
:
WorkerId
)
{
self
.remove_or_clear_worker_blocks
(
worker_id
,
true
);
fn
clear_all_blocks
(
&
self
,
lookup
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
WorkerLookup
>
,
worker_id
:
WorkerId
,
)
{
self
.remove_or_clear_worker_blocks
(
lookup
,
worker_id
,
true
);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
let
mut
worker_ids
:
Vec
<
WorkerId
>
=
self
.lookup
.read
()
.keys
()
.map
(|
w
|
w
.worker_id
)
.collect
::
<
FxHashSet
<
_
>>
()
.into_iter
()
.tree_sizes
.iter
()
.map
(|
entry
|
entry
.key
()
.worker_id
)
.collect
();
worker_ids
.sort_unstable
();
worker_ids
.dedup
();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
tracing
::
debug!
(
"Dumping concurrent radix tree as events (contains information about {:?} workers)"
,
self
.lookup
.read
()
.len
()
);
/// Uses BFS traversal over the shared tree. Since all worker/block membership is
/// stored in the tree nodes themselves, this can be called from any thread without
/// needing per-thread lookup state.
fn
dump_tree_as_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
tracing
::
debug!
(
"Dumping concurrent radix tree as events"
);
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
...
...
@@ -567,15 +598,6 @@ impl ConcurrentRadixTree {
events
}
/// Get total number of blocks across all workers.
pub
fn
current_size
(
&
self
)
->
usize
{
self
.lookup
.read
()
.values
()
.map
(|
inner
|
inner
.read
()
.len
())
.sum
()
}
}
// ============================================================================
...
...
@@ -583,646 +605,39 @@ impl ConcurrentRadixTree {
// ============================================================================
impl
SyncIndexer
for
ConcurrentRadixTree
{
fn
find_matches
(
&
self
,
sequence
:
&
[
LocalBlockHash
],
early_exit
:
bool
)
->
OverlapScores
{
// Delegate to the existing find_matches method
self
.find_matches_impl
(
sequence
,
early_exit
)
}
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
self
.apply_event
(
event
)
}
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
self
.remove_worker
(
worker_id
);
}
fn
dump_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
self
.dump_tree_as_events
()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
test_utils
::{
create_remove_event
,
create_store_event
};
use
std
::
sync
::
Arc
;
use
std
::
thread
;
#[test]
fn
test_concurrent_radix_tree_basic
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_1
=
0
;
let
worker_2
=
1
;
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
trie
.lookup
.read
()
.len
(),
1
);
assert_eq!
(
trie
.lookup
.read
()
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
()
.read
()
.len
(),
3
);
trie
.apply_event
(
create_store_event
(
worker_2
,
1
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
let
scores
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
1
);
assert_eq!
(
trie
.lookup
.read
()
.len
(),
2
);
}
#[test]
fn
test_concurrent_radix_tree_remove
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_1
=
0
;
let
worker_2
=
1
;
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_2
,
1
,
vec!
[
1
,
4
,
5
],
None
))
.unwrap
();
trie
.apply_event
(
create_remove_event
(
worker_2
,
2
,
vec!
[
5
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.read
()
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.read
()
.len
(),
2
);
trie
.apply_event
(
create_remove_event
(
worker_2
,
3
,
vec!
[
4
]))
.unwrap
();
assert_eq!
(
trie
.lookup
.read
()
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
()
.read
()
.len
(),
1
);
}
#[test]
fn
test_concurrent_radix_tree_apply_event_errors
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_0
=
0
;
// Parent block not found
let
result
=
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
Some
(
ExternalSequenceBlockHash
(
12345
)),
));
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.unwrap_err
(),
KvCacheEventError
::
ParentBlockNotFound
));
}
#[test]
fn
test_clear_all_blocks
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
0
,
1
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
0
,
2
,
3
],
None
))
.unwrap
();
let
result
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
0
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
2
);
trie
.clear_all_blocks
(
worker_0
);
assert
!
(
trie
.lookup
.read
()
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert
!
(
trie
.lookup
.read
()
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
.unwrap
()
.read
()
.is_empty
()
);
let
result
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
0
),
LocalBlockHash
(
2
)],
false
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)],
2
);
}
#[test]
fn
test_remove_worker
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_0
=
0
;
let
worker_1
=
1
;
trie
.apply_event
(
create_store_event
(
worker_0
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_1
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
assert_eq!
(
trie
.lookup
.read
()
.len
(),
2
);
trie
.remove_worker
(
worker_0
);
assert
!
(
!
trie
.lookup
.read
()
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
))
);
assert_eq!
(
trie
.lookup
.read
()
.len
(),
1
);
let
result
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
)
.scores
;
assert_eq!
(
result
.len
(),
1
);
assert
!
(
!
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_0
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)));
}
#[test]
fn
test_concurrent_radix_tree_default
()
{
let
trie
:
ConcurrentRadixTree
=
Default
::
default
();
assert
!
(
trie
.root
.read
()
.children
.is_empty
());
assert
!
(
trie
.root
.read
()
.workers
.is_empty
());
assert
!
(
trie
.lookup
.read
()
.is_empty
());
}
#[test]
fn
test_concurrent_find_matches
()
{
let
trie
=
Arc
::
new
(
ConcurrentRadixTree
::
new
());
// Populate tree
trie
.apply_event
(
create_store_event
(
0
,
0
,
vec!
[
1
,
2
,
3
,
4
,
5
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
1
,
0
,
vec!
[
1
,
2
,
6
,
7
,
8
],
None
))
.unwrap
();
let
sequence
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
LocalBlockHash
(
4
),
LocalBlockHash
(
5
),
];
// Spawn multiple threads doing concurrent find_matches
let
handles
:
Vec
<
_
>
=
(
0
..
10
)
.map
(|
_
|
{
let
tree
=
trie
.clone
();
let
seq
=
sequence
.clone
();
thread
::
spawn
(
move
||
tree
.find_matches_impl
(
&
seq
,
false
))
})
.collect
();
// All should return the same result
let
expected_worker_0_score
=
5
;
let
expected_worker_1_score
=
2
;
for
h
in
handles
{
let
result
=
h
.join
()
.unwrap
();
assert_eq!
(
result
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
0
))
.unwrap
(),
&
expected_worker_0_score
);
assert_eq!
(
result
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
1
))
.unwrap
(),
&
expected_worker_1_score
);
}
}
#[test]
fn
test_concurrent_read_write
()
{
let
trie
=
Arc
::
new
(
ConcurrentRadixTree
::
new
());
// Pre-populate
for
i
in
0
..
5
{
trie
.apply_event
(
create_store_event
(
i
,
0
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
}
fn
worker
(
&
self
,
event_receiver
:
flume
::
Receiver
<
WorkerTask
>
)
->
anyhow
::
Result
<
()
>
{
let
mut
lookup
=
FxHashMap
::
default
();
let
sequence
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)];
// Spawn readers
let
reader_handles
:
Vec
<
_
>
=
(
0
..
5
)
.map
(|
_
|
{
let
tree
=
trie
.clone
();
let
seq
=
sequence
.clone
();
thread
::
spawn
(
move
||
{
for
_
in
0
..
100
{
let
_
=
tree
.find_matches_impl
(
&
seq
,
false
);
}
})
})
.collect
();
// Spawn writers (adding more workers)
let
writer_handles
:
Vec
<
_
>
=
(
5
..
10
)
.map
(|
i
|
{
let
tree
=
trie
.clone
();
thread
::
spawn
(
move
||
{
for
j
in
0
..
10
{
let
_
=
tree
.apply_event
(
create_store_event
(
i
,
j
,
vec!
[
1
,
2
,
3
,
4
+
j
],
None
));
}
})
})
.collect
();
// Wait for all threads
for
h
in
reader_handles
{
h
.join
()
.unwrap
();
}
for
h
in
writer_handles
{
h
.join
()
.unwrap
();
while
let
Ok
(
task
)
=
event_receiver
.recv
()
{
match
task
{
WorkerTask
::
Event
(
event
)
=>
{
if
let
Err
(
e
)
=
self
.apply_event
(
&
mut
lookup
,
event
)
{
tracing
::
warn!
(
"Failed to apply event: {:?}"
,
e
);
}
// Tree should have 10 workers now
assert_eq!
(
trie
.get_workers
()
.len
(),
10
);
}
#[test]
fn
test_remove_parent_does_not_cascade
()
{
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_1
=
0
;
// Create a chain: root -> block1 -> block2 -> block3
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
let
worker_key
=
WorkerWithDpRank
::
from_worker_id
(
worker_1
);
assert_eq!
(
trie
.lookup
.read
()
.get
(
&
worker_key
)
.unwrap
()
.read
()
.len
(),
3
);
// Remove ONLY block1 -- descendants should NOT be cascade-removed
trie
.apply_event
(
create_remove_event
(
worker_1
,
2
,
vec!
[
1
]))
.unwrap
();
let
lk
=
trie
.lookup
.read
();
let
worker_lookup
=
lk
.get
(
&
worker_key
)
.unwrap
()
.read
();
assert
!
(
!
worker_lookup
.contains_key
(
&
ExternalSequenceBlockHash
(
100
)),
"block1 should be removed"
);
assert
!
(
worker_lookup
.contains_key
(
&
ExternalSequenceBlockHash
(
200
)),
"block2 should remain (no cascade)"
);
assert
!
(
worker_lookup
.contains_key
(
&
ExternalSequenceBlockHash
(
300
)),
"block3 should remain (no cascade)"
);
assert_eq!
(
worker_lookup
.len
(),
2
);
}
#[test]
fn
test_remove_all_blocks_individually
()
{
// Verifies that explicitly removing all blocks (as the engine would)
// cleans up fully, even without cascade.
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_1
=
0
;
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
let
worker_key
=
WorkerWithDpRank
::
from_worker_id
(
worker_1
);
// Remove all three blocks explicitly in one event
trie
.apply_event
(
create_remove_event
(
worker_1
,
2
,
vec!
[
1
,
2
,
3
]))
.unwrap
();
let
lk
=
trie
.lookup
.read
();
let
worker_lookup
=
lk
.get
(
&
worker_key
)
.unwrap
()
.read
();
assert_eq!
(
worker_lookup
.len
(),
0
,
"all blocks should be removed"
);
}
#[test]
fn
test_find_matches_with_stale_entries
()
{
// Two workers share a full path. Remove worker_1 from the root block
// only (simulating a partial remove). find_matches should still
// produce correct scores for worker_2, and worker_1 should score at
// the stale descendant depth (transiently inflated but not a crash).
let
trie
=
ConcurrentRadixTree
::
new
();
let
worker_1
=
0
;
let
worker_2
=
1
;
// Both workers have blocks 1 -> 2 -> 3
trie
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
trie
.apply_event
(
create_store_event
(
worker_2
,
2
,
vec!
[
1
,
2
,
3
],
None
))
.unwrap
();
// Remove worker_1 from block 1 only (no cascade to 2,3)
trie
.apply_event
(
create_remove_event
(
worker_1
,
3
,
vec!
[
1
]))
.unwrap
();
let
scores
=
trie
.find_matches_impl
(
&
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)],
false
,
);
// worker_2 was never removed, should have full depth
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
)),
Some
(
&
3
),
"worker_2 should score 3 (fully present)"
);
// worker_1 was removed from block 1 so it drops out at depth 1.
// But because blocks 2 and 3 still have worker_1 (stale), the
// child_count > active_count path fires and detects the dropout.
// The exact score depends on the detection logic: worker_1 is absent
// from block 1's workers, so it should be scored at depth 0 from the
// first child initialization (it won't appear in `active` at all).
// So worker_1 should NOT appear in scores (it was never in active).
assert
!
(
!
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)),
"worker_1 should not appear in scores (removed from root-level block)"
);
}
// ========================================================================
// ThreadPoolIndexer<ConcurrentRadixTree> Tests
// ========================================================================
mod
thread_pool_indexer_tests
{
use
tokio
::
time
::
Duration
;
use
super
::
*
;
use
crate
::
indexer
::{
KvIndexerInterface
,
ThreadPoolIndexer
};
fn
make_indexer
(
num_workers
:
usize
,
kv_block_size
:
u32
,
)
->
ThreadPoolIndexer
<
ConcurrentRadixTree
>
{
ThreadPoolIndexer
::
new
(
ConcurrentRadixTree
::
new
(),
num_workers
,
kv_block_size
)
WorkerTask
::
RemoveWorker
(
worker_id
)
=>
{
self
.remove_or_clear_worker_blocks
(
&
mut
lookup
,
worker_id
,
false
);
}
#[tokio::test]
async
fn
test_thread_pool_indexer_basic
()
{
let
indexer
=
make_indexer
(
4
,
16
);
let
worker_1
=
0
;
let
worker_2
=
1
;
indexer
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.await
;
indexer
.apply_event
(
create_store_event
(
worker_2
,
1
,
vec!
[
1
,
4
,
5
],
None
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
let
scores
=
indexer
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
))
.unwrap
(),
&
3
);
assert_eq!
(
scores
.scores
.get
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
))
.unwrap
(),
&
1
);
indexer
.shutdown
();
WorkerTask
::
DumpEvents
(
_
sender
)
=>
{
// Handled directly via dump_events() on the shared tree.
// Should not be reached, but respond with empty to avoid blocking.
let
_
=
_
sender
.send
(
Ok
(
Vec
::
new
()));
}
#[tokio::test]
async
fn
test_thread_pool_indexer_remove_worker
()
{
let
indexer
=
make_indexer
(
2
,
16
);
let
worker_0
=
0
;
let
worker_1
=
1
;
indexer
.apply_event
(
create_store_event
(
worker_0
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.await
;
indexer
.apply_event
(
create_store_event
(
worker_1
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
assert_eq!
(
indexer
.backend
()
.get_workers
()
.len
(),
2
);
indexer
.remove_worker
(
worker_0
)
.await
;
let
workers
=
indexer
.backend
()
.get_workers
();
assert_eq!
(
workers
.len
(),
1
);
assert
!
(
!
workers
.contains
(
&
worker_0
));
assert
!
(
workers
.contains
(
&
worker_1
));
indexer
.shutdown
();
}
#[tokio::test]
async
fn
test_thread_pool_indexer_dump_events
()
{
let
indexer
=
make_indexer
(
2
,
16
);
indexer
.apply_event
(
create_store_event
(
0
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
let
events
=
indexer
.dump_events
()
.await
.unwrap
();
assert_eq!
(
events
.len
(),
3
);
indexer
.shutdown
();
}
#[tokio::test]
async
fn
test_thread_pool_indexer_find_matches_for_request
()
{
let
indexer
=
make_indexer
(
2
,
1
);
indexer
.apply_event
(
create_store_event
(
0
,
1
,
vec!
[
100
,
200
,
300
],
None
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
let
scores
=
indexer
.find_matches_for_request
(
&
[
100
,
200
,
300
],
None
)
.await
;
assert
!
(
scores
.is_ok
());
indexer
.shutdown
();
}
#[tokio::test]
async
fn
test_thread_pool_indexer_sticky_routing
()
{
let
indexer
=
make_indexer
(
4
,
16
);
for
i
in
0
..
10
{
indexer
.apply_event
(
create_store_event
(
0
,
i
,
vec!
[
i
],
None
))
.await
;
}
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
assert_eq!
(
indexer
.backend
()
.current_size
(),
10
);
indexer
.shutdown
();
}
#[tokio::test]
async
fn
test_thread_pool_indexer_multiple_workers
()
{
let
indexer
=
make_indexer
(
4
,
16
);
for
worker_id
in
0
..
8
{
indexer
.apply_event
(
create_store_event
(
worker_id
,
1
,
vec!
[
1
,
2
,
worker_id
+
10
],
None
,
))
.await
;
}
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
assert_eq!
(
indexer
.backend
()
.get_workers
()
.len
(),
8
);
let
scores
=
indexer
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
8
);
for
(
_
,
score
)
in
scores
.scores
.iter
()
{
assert_eq!
(
*
score
,
2
);
}
indexer
.shutdown
();
WorkerTask
::
Terminate
=>
{
break
;
}
#[tokio::test]
async
fn
test_thread_pool_indexer_shutdown_idempotent
()
{
let
indexer
=
make_indexer
(
2
,
16
);
indexer
.apply_event
(
create_store_event
(
0
,
1
,
vec!
[
1
,
2
,
3
],
None
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
indexer
.shutdown
();
indexer
.shutdown
();
}
#[tokio::test]
async
fn
test_thread_pool_indexer_concurrent_operations
()
{
use
std
::
sync
::
Arc
;
let
indexer
=
Arc
::
new
(
make_indexer
(
4
,
16
));
for
worker_id
in
0
..
4
{
indexer
.apply_event
(
create_store_event
(
worker_id
,
1
,
vec!
[
1
,
2
,
3
,
4
,
5
],
None
))
.await
;
}
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
let
sequence
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
)];
let
mut
handles
=
Vec
::
new
();
for
_
in
0
..
10
{
let
idx
=
indexer
.clone
();
let
seq
=
sequence
.clone
();
handles
.push
(
tokio
::
spawn
(
async
move
{
idx
.find_matches
(
seq
)
.await
.unwrap
()
},
));
tracing
::
debug!
(
"ConcurrentRadixTree worker thread shutting down"
);
Ok
(())
}
for
handle
in
handles
{
let
scores
=
handle
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
4
);
fn
find_matches
(
&
self
,
sequence
:
&
[
LocalBlockHash
],
early_exit
:
bool
)
->
OverlapScores
{
self
.find_matches_impl
(
sequence
,
early_exit
)
}
indexer
.shutdown
();
}
fn
dump_events
(
&
self
)
->
Option
<
Vec
<
RouterEvent
>>
{
Some
(
self
.dump_tree_as_events
())
}
}
lib/kv-router/src/indexer.rs
View file @
9e5014da
...
...
@@ -359,6 +359,14 @@ pub trait KvIndexerInterface {
async
fn
flush
(
&
self
)
->
usize
;
}
pub
enum
WorkerTask
{
Event
(
RouterEvent
),
/// Permanently remove a worker from tracking (keep_worker: false).
RemoveWorker
(
WorkerId
),
DumpEvents
(
oneshot
::
Sender
<
anyhow
::
Result
<
Vec
<
RouterEvent
>>>
),
Terminate
,
}
// ============================================================================
// SyncIndexer trait and ThreadPoolIndexer generic wrapper
// ============================================================================
...
...
@@ -373,17 +381,18 @@ pub trait KvIndexerInterface {
/// - Sticky event routing to N worker threads
/// - Inline reads on the caller's thread (no channel dispatch for find_matches)
pub
trait
SyncIndexer
:
Send
+
Sync
+
'static
{
fn
worker
(
&
self
,
event_receiver
:
flume
::
Receiver
<
WorkerTask
>
)
->
anyhow
::
Result
<
()
>
;
/// Find matches for a sequence of block hashes.
fn
find_matches
(
&
self
,
sequence
:
&
[
LocalBlockHash
],
early_exit
:
bool
)
->
OverlapScores
;
/// Apply a router event to the data structure.
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
;
/// Remove all entries for a worker.
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
);
/// Dump the data structure as router events for reconstruction.
fn
dump_events
(
&
self
)
->
Vec
<
RouterEvent
>
;
/// Dump events directly from the shared structure, bypassing worker channels.
/// Returns `Some(events)` for backends whose tree state is fully shared (e.g.
/// ConcurrentRadixTree). Returns `None` for backends that keep per-thread
/// state and must dump via the worker channel.
fn
dump_events
(
&
self
)
->
Option
<
Vec
<
RouterEvent
>>
{
None
}
}
/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
...
...
@@ -415,9 +424,9 @@ pub struct ThreadPoolIndexer<T: SyncIndexer> {
/// Counter for round-robin assignment of new WorkerIds.
worker_assignment_count
:
AtomicUsize
,
/// Channels to send
event
s to worker threads (one per thread).
/// Sending `
Non
e` signals the thread to shut down.
worker_event_channels
:
Vec
<
flume
::
Sender
<
Option
<
RouterEvent
>
>>
,
/// Channels to send
task
s to worker threads (one per thread).
/// Sending `
WorkerTask::Terminat
e` signals the thread to shut down.
worker_event_channels
:
Vec
<
flume
::
Sender
<
WorkerTask
>>
,
/// Number of worker threads.
num_workers
:
usize
,
...
...
@@ -450,18 +459,13 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
let
mut
worker_event_senders
=
Vec
::
new
();
let
mut
thread_handles
=
Vec
::
new
();
for
_
in
0
..
num_workers
{
let
(
event_sender
,
event_receiver
)
=
flume
::
unbounded
::
<
Option
<
RouterEvent
>
>
();
let
(
event_sender
,
event_receiver
)
=
flume
::
unbounded
::
<
WorkerTask
>
();
worker_event_senders
.push
(
event_sender
);
let
backend
=
Arc
::
clone
(
&
backend
);
let
handle
=
std
::
thread
::
spawn
(
move
||
{
while
let
Ok
(
Some
(
event
))
=
event_receiver
.recv
()
{
if
let
Err
(
e
)
=
backend
.apply_event
(
event
)
{
tracing
::
warn!
(
"Failed to apply event: {:?}"
,
e
);
}
}
tracing
::
debug!
(
"Worker thread shutting down"
);
backend
.worker
(
event_receiver
)
.unwrap
();
});
thread_handles
.push
(
handle
);
}
...
...
@@ -530,7 +534,7 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
});
// Send event to the assigned worker thread
if
let
Err
(
e
)
=
self
.worker_event_channels
[
thread_idx
]
.send
(
Some
(
event
))
{
if
let
Err
(
e
)
=
self
.worker_event_channels
[
thread_idx
]
.send
(
WorkerTask
::
Event
(
event
))
{
tracing
::
error!
(
"Failed to send event to worker thread {}: {:?}"
,
thread_idx
,
...
...
@@ -540,14 +544,34 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
async
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
// Execute inline - the backend is thread-safe
self
.backend
.remove_worker
(
worker_id
);
// Route to the worker's assigned thread (if any), otherwise broadcast
// to all threads since dp_ranks may be spread across threads.
let
thread_idx
=
self
.worker_assignments
.get
(
&
worker_id
)
.map
(|
v
|
*
v
);
match
thread_idx
{
Some
(
idx
)
=>
{
if
let
Err
(
e
)
=
self
.worker_event_channels
[
idx
]
.send
(
WorkerTask
::
RemoveWorker
(
worker_id
))
{
tracing
::
error!
(
"Failed to send RemoveWorker to worker thread {}: {:?}"
,
idx
,
e
);
}
}
None
=>
{
// Worker was never assigned a thread - broadcast to all
for
channel
in
&
self
.worker_event_channels
{
let
_
=
channel
.send
(
WorkerTask
::
RemoveWorker
(
worker_id
));
}
}
}
}
fn
shutdown
(
&
self
)
{
// Send shutdown signal
(None)
to all worker threads
// Send shutdown signal to all worker threads
for
channel
in
self
.worker_event_channels
.iter
()
{
let
_
=
channel
.send
(
Non
e
);
let
_
=
channel
.send
(
WorkerTask
::
Terminat
e
);
}
// Take ownership of thread handles and join them
...
...
@@ -565,8 +589,41 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
async
fn
dump_events
(
&
self
)
->
Result
<
Vec
<
RouterEvent
>
,
KvRouterError
>
{
// Execute inline - the backend is thread-safe
Ok
(
self
.backend
.dump_events
())
// Fast path: backend can dump directly from shared state (e.g. ConcurrentRadixTree).
if
let
Some
(
events
)
=
self
.backend
.dump_events
()
{
return
Ok
(
events
);
}
// Slow path: collect from each worker thread via channel (e.g. PositionalIndexer).
let
mut
receivers
=
Vec
::
new
();
for
channel
in
&
self
.worker_event_channels
{
let
(
resp_tx
,
resp_rx
)
=
oneshot
::
channel
::
<
anyhow
::
Result
<
Vec
<
RouterEvent
>>>
();
let
dump_req
=
WorkerTask
::
DumpEvents
(
resp_tx
);
channel
.send
(
dump_req
)
.map_err
(|
_
|
KvRouterError
::
IndexerOffline
)
?
;
receivers
.push
(
resp_rx
);
}
let
mut
event_id_counter
=
0
;
let
mut
all_events
=
Vec
::
new
();
for
resp_rx
in
receivers
{
let
mut
events
=
resp_rx
.await
.map_err
(|
_
|
KvRouterError
::
IndexerDroppedRequest
)
?
.map_err
(|
_
|
KvRouterError
::
IndexerOffline
)
?
;
for
event
in
&
mut
events
{
event
.event.event_id
=
event_id_counter
;
event_id_counter
+=
1
;
}
all_events
.extend
(
events
);
}
Ok
(
all_events
)
}
async
fn
process_routing_decision_for_request
(
...
...
@@ -2354,6 +2411,16 @@ mod tests {
index
.shutdown
();
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_shutdown_idempotent
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
index
.shutdown
();
index
.shutdown
();
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_find_matches_for_request
(
variant
:
&
str
)
{
...
...
lib/kv-router/src/nested_map.rs
View file @
9e5014da
...
...
@@ -21,10 +21,10 @@
//! `KvIndexerInterface` with sticky event routing and worker threads, wrap it
//! in a `ThreadPoolIndexer`.
use
dashmap
::
DashMap
;
use
parking_lot
::
RwLock
;
use
rustc_hash
::{
FxBuildHasher
,
FxHashMap
,
FxHashSet
};
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
};
use
crate
::
indexer
::
SyncIndexer
;
use
crate
::
indexer
::
{
SyncIndexer
,
WorkerTask
}
;
use
crate
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEvent
,
KvCacheEventData
,
KvCacheEventError
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
WorkerId
,
WorkerWithDpRank
,
...
...
@@ -100,7 +100,7 @@ impl SeqEntry {
}
}
type
LevelIndex
=
RwLock
<
FxHashMap
<
ExternalSequenceBlockHash
,
(
usize
,
LocalBlockHash
)
>
>
;
pub
type
LevelIndex
=
FxHashMap
<
ExternalSequenceBlockHash
,
(
usize
,
LocalBlockHash
)
>
;
/// Positional HashMap-based KV cache index.
///
...
...
@@ -108,11 +108,8 @@ type LevelIndex = RwLock<FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlock
/// All methods are synchronous and thread-safe.
pub
struct
PositionalIndexer
{
index
:
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
/// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash)
/// Enables efficient remove operations without global flat reverse map.
/// Uses a single RwLock rather than DashMap because structural mutations
/// (adding/removing workers) are rare; the hot path is read-only.
worker_blocks
:
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>>
,
tree_sizes
:
DashMap
<
WorkerWithDpRank
,
AtomicUsize
,
FxBuildHasher
>
,
jump_size
:
usize
,
}
...
...
@@ -129,7 +126,7 @@ impl PositionalIndexer {
Self
{
index
:
DashMap
::
with_hasher
(
FxBuildHasher
),
worker_blocks
:
RwLock
::
new
(
FxHashMap
::
default
()
),
tree_sizes
:
DashMap
::
with_hasher
(
FxBuildHasher
),
jump_size
,
}
}
...
...
@@ -140,83 +137,37 @@ impl PositionalIndexer {
// ============================================================================
impl
SyncIndexer
for
PositionalIndexer
{
fn
find_matches
(
&
self
,
sequence
:
&
[
LocalBlockHash
],
early_exit
:
bool
)
->
OverlapScores
{
self
.jump_search_matches
(
sequence
,
early_exit
)
}
fn
worker
(
&
self
,
event_receiver
:
flume
::
Receiver
<
WorkerTask
>
)
->
anyhow
::
Result
<
()
>
{
let
mut
worker_blocks
=
FxHashMap
::
default
();
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
->
Result
<
(),
KvCacheEventError
>
{
Self
::
apply_event_impl
(
&
self
.index
,
&
self
.worker_blocks
,
event
)
while
let
Ok
(
task
)
=
event_receiver
.recv
()
{
match
task
{
WorkerTask
::
Event
(
event
)
=>
{
if
let
Err
(
e
)
=
self
.apply_event
(
&
mut
worker_blocks
,
event
)
{
tracing
::
warn!
(
"Failed to apply event: {:?}"
,
e
);
}
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
Self
::
remove_or_clear_worker_blocks_impl
(
&
self
.index
,
&
self
.worker_blocks
,
worker_id
,
false
,
);
}
fn
dump_events
(
&
self
)
->
Vec
<
RouterEvent
>
{
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
let
wb
=
self
.worker_blocks
.read
();
for
(
worker
,
level_index
)
in
wb
.iter
()
{
let
worker
=
*
worker
;
let
worker_map
=
level_index
.read
();
// Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay.
let
mut
blocks
:
Vec
<
_
>
=
worker_map
.iter
()
.map
(|(
seq_hash
,
(
pos
,
local_hash
))|
(
*
pos
,
*
local_hash
,
*
seq_hash
))
.collect
();
blocks
.sort_unstable_by_key
(|(
pos
,
_
,
_
)|
*
pos
);
// Track one valid seq_hash per position for parent_hash synthesis.
let
mut
last_at_position
:
FxHashMap
<
usize
,
ExternalSequenceBlockHash
>
=
FxHashMap
::
default
();
for
(
pos
,
local_hash
,
seq_hash
)
in
blocks
{
let
parent_hash
=
if
pos
==
0
{
None
}
else
{
match
last_at_position
.get
(
&
(
pos
-
1
))
{
Some
(
&
parent
)
=>
Some
(
parent
),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
position
=
pos
,
"Orphaned block at position with no parent; skipping in dump"
);
continue
;
WorkerTask
::
RemoveWorker
(
worker_id
)
=>
{
self
.remove_or_clear_worker_blocks_impl
(
&
mut
worker_blocks
,
worker_id
,
false
);
}
WorkerTask
::
DumpEvents
(
sender
)
=>
{
let
events
=
self
.dump_events
(
&
worker_blocks
);
if
let
Err
(
e
)
=
sender
.send
(
Ok
(
events
))
{
tracing
::
warn!
(
"Failed to send events: {:?}"
,
e
);
}
}
WorkerTask
::
Terminate
=>
{
break
;
}
};
events
.push
(
RouterEvent
{
worker_id
:
worker
.worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
seq_hash
,
tokens_hash
:
local_hash
,
mm_extra_info
:
None
,
}],
}),
dp_rank
:
worker
.dp_rank
,
},
});
event_id
+=
1
;
last_at_position
.insert
(
pos
,
seq_hash
);
}
}
events
tracing
::
debug!
(
"PositionalIndexer worker thread shutting down"
);
Ok
(())
}
fn
find_matches
(
&
self
,
sequence
:
&
[
LocalBlockHash
],
early_exit
:
bool
)
->
OverlapScores
{
self
.jump_search_matches
(
sequence
,
early_exit
)
}
}
...
...
@@ -227,9 +178,9 @@ impl SyncIndexer for PositionalIndexer {
impl
PositionalIndexer
{
/// Process an event using the provided index and worker_blocks.
/// This is called from worker threads.
fn
apply_event
_impl
(
index
:
&
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
worker_blocks
:
&
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
>
,
pub
fn
apply_event
(
&
self
,
worker_blocks
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
event
:
RouterEvent
,
)
->
Result
<
(),
KvCacheEventError
>
{
let
(
worker_id
,
kv_event
)
=
(
event
.worker_id
,
event
.event
);
...
...
@@ -245,50 +196,32 @@ impl PositionalIndexer {
match
op
{
KvCacheEventData
::
Stored
(
store_data
)
=>
{
S
elf
::
store_blocks_impl
(
index
,
worker_blocks
,
worker
,
store_data
,
id
)
?
;
s
elf
.
store_blocks_impl
(
worker_blocks
,
worker
,
store_data
,
id
)
?
;
Ok
(())
}
KvCacheEventData
::
Removed
(
remove_data
)
=>
{
Self
::
remove_blocks_impl
(
index
,
worker_blocks
,
worker
,
&
remove_data
.block_hashes
,
id
,
)
?
;
self
.remove_blocks_impl
(
worker_blocks
,
worker
,
&
remove_data
.block_hashes
,
id
)
?
;
Ok
(())
}
KvCacheEventData
::
Cleared
=>
{
S
elf
::
clear_worker_blocks_impl
(
index
,
worker_blocks
,
worker_id
);
s
elf
.
clear_worker_blocks_impl
(
worker_blocks
,
worker_id
);
Ok
(())
}
}
}
fn
store_blocks_impl
(
index
:
&
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
worker_blocks
:
&
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
>
,
&
self
,
worker_blocks
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
worker
:
WorkerWithDpRank
,
store_data
:
KvCacheStoreData
,
event_id
:
u64
,
)
->
Result
<
(),
KvCacheEventError
>
{
let
worker_map
=
worker_blocks
.entry
(
worker
)
.or_default
();
// Determine starting position based on parent_hash
let
start_pos
=
match
store_data
.parent_hash
{
Some
(
parent_hash
)
=>
{
let
wb
=
worker_blocks
.read
();
let
Some
(
level_index
)
=
wb
.get
(
&
worker
)
else
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
event_id
,
parent_hash
=
?
parent_hash
,
);
return
Err
(
KvCacheEventError
::
ParentBlockNotFound
);
};
let
worker_map
=
level_index
.read
();
let
Some
(
entry
)
=
worker_map
.get
(
&
parent_hash
)
else
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
...
...
@@ -304,42 +237,45 @@ impl PositionalIndexer {
None
=>
0
,
// Start from position 0
};
if
!
worker_blocks
.read
()
.contains_key
(
&
worker
)
{
worker_blocks
.write
()
.entry
(
worker
)
.or_insert_with
(||
RwLock
::
new
(
FxHashMap
::
default
()));
}
let
worker_blocks_entry
=
worker_blocks
.entry
(
worker
)
.or_default
();
let
wb
=
worker_blocks
.read
();
let
mut
worker_map
=
wb
.get
(
&
worker
)
.unwrap
()
.write
();
let
num_stored_blocks
=
store_data
.blocks
.len
();
for
(
i
,
block_data
)
in
store_data
.blocks
.into_iter
()
.enumerate
()
{
let
position
=
start_pos
+
i
;
let
local_hash
=
block_data
.tokens_hash
;
let
seq_hash
=
block_data
.block_hash
;
index
self
.
index
.entry
((
position
,
local_hash
))
.and_modify
(|
entry
|
entry
.insert
(
seq_hash
,
worker
))
.or_insert_with
(||
SeqEntry
::
new
(
seq_hash
,
worker
));
// Insert into worker_blocks: worker -> seq_hash -> (position, local_hash)
worker_map
.insert
(
seq_hash
,
(
position
,
local_hash
));
worker_blocks_entry
.insert
(
seq_hash
,
(
position
,
local_hash
));
}
match
self
.tree_sizes
.get
(
&
worker
)
{
Some
(
size
)
=>
{
size
.fetch_add
(
num_stored_blocks
,
Ordering
::
Relaxed
);
}
None
=>
{
self
.tree_sizes
.insert
(
worker
,
AtomicUsize
::
new
(
num_stored_blocks
));
}
}
Ok
(())
}
fn
remove_blocks_impl
(
index
:
&
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
worker_blocks
:
&
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
>
,
&
self
,
worker_blocks
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
worker
:
WorkerWithDpRank
,
seq_hashes
:
&
Vec
<
ExternalSequenceBlockHash
>
,
event_id
:
u64
,
)
->
Result
<
(),
KvCacheEventError
>
{
let
wb
=
worker_blocks
.read
();
let
level_index
=
wb
.get
(
&
worker
)
.ok_or_else
(||
{
let
worker_map
=
worker_blocks
.get_mut
(
&
worker
)
.ok_or_else
(||
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
...
...
@@ -350,7 +286,7 @@ impl PositionalIndexer {
KvCacheEventError
::
BlockNotFound
})
?
;
let
mut
worker_map
=
level_index
.write
()
;
let
mut
num_removed_blocks
=
0
;
for
seq_hash
in
seq_hashes
{
let
Some
((
position
,
local_hash
))
=
worker_map
.remove
(
seq_hash
)
else
{
...
...
@@ -361,13 +297,23 @@ impl PositionalIndexer {
block_hash
=
?
seq_hash
,
"Failed to find block to remove; skipping remove operation"
);
if
let
Some
(
size
)
=
self
.tree_sizes
.get
(
&
worker
)
{
size
.fetch_sub
(
num_removed_blocks
,
Ordering
::
Relaxed
);
}
return
Err
(
KvCacheEventError
::
BlockNotFound
);
};
// Remove from index
if
let
Some
(
mut
entry
)
=
index
.get_mut
(
&
(
position
,
local_hash
))
{
if
let
Some
(
mut
entry
)
=
self
.index
.get_mut
(
&
(
position
,
local_hash
))
{
let
_
=
entry
.remove
(
*
seq_hash
,
worker
);
}
num_removed_blocks
+=
1
;
}
if
let
Some
(
size
)
=
self
.tree_sizes
.get
(
&
worker
)
{
size
.fetch_sub
(
num_removed_blocks
,
Ordering
::
Relaxed
);
}
Ok
(())
...
...
@@ -376,63 +322,114 @@ impl PositionalIndexer {
/// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked.
/// Static version for use in worker threads.
fn
clear_worker_blocks_impl
(
index
:
&
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
worker_blocks
:
&
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
>
,
&
self
,
worker_blocks
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
worker_id
:
WorkerId
,
)
{
Self
::
remove_or_clear_worker_blocks_impl
(
index
,
worker_blocks
,
worker_id
,
true
);
}
/// Get total number of blocks across all workers.
pub
fn
current_size
(
&
self
)
->
usize
{
self
.worker_blocks
.read
()
.values
()
.map
(|
level_index
|
level_index
.read
()
.len
())
.sum
()
}
/// Remove a worker and all their blocks completely from the index.
#[allow(dead_code)]
fn
remove_worker_blocks
(
&
self
,
worker_id
:
WorkerId
)
{
Self
::
remove_or_clear_worker_blocks_impl
(
&
self
.index
,
&
self
.worker_blocks
,
worker_id
,
false
,
);
self
.remove_or_clear_worker_blocks_impl
(
worker_blocks
,
worker_id
,
true
);
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed.
fn
remove_or_clear_worker_blocks_impl
(
index
:
&
DashMap
<
(
usize
,
LocalBlockHash
),
SeqEntry
,
FxBuildHasher
>
,
worker_blocks
:
&
RwLock
<
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
>
,
&
self
,
worker_blocks
:
&
mut
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
worker_id
:
WorkerId
,
keep_worker
:
bool
,
)
{
let
workers
:
Vec
<
WorkerWithDpRank
>
=
worker_blocks
.read
()
.keys
()
.filter
(|
w
|
w
.worker_id
==
worker_id
)
.copied
()
.iter
()
.filter
(|
entry
|
entry
.0
.worker_id
==
worker_id
)
.map
(|
entry
|
*
entry
.0
)
.collect
();
let
mut
wb
=
worker_blocks
.write
();
for
worker
in
workers
{
if
let
Some
(
worker_map
)
=
w
b
.remove
(
&
worker
)
{
for
(
seq_hash
,
(
position
,
local_hash
))
in
worker_map
.
read
()
.
iter
()
{
if
let
Some
(
mut
entry
)
=
index
.get_mut
(
&
(
*
position
,
*
local_hash
))
{
if
let
Some
(
worker_map
)
=
w
orker_blocks
.remove
(
&
worker
)
{
for
(
seq_hash
,
(
position
,
local_hash
))
in
worker_map
.iter
()
{
if
let
Some
(
mut
entry
)
=
self
.
index
.get_mut
(
&
(
*
position
,
*
local_hash
))
{
let
_
=
entry
.remove
(
*
seq_hash
,
worker
);
}
}
}
if
keep_worker
{
wb
.insert
(
worker
,
RwLock
::
new
(
FxHashMap
::
default
()));
// Re-insert worker with empty map to keep it tracked
worker_blocks
.insert
(
worker
,
FxHashMap
::
default
());
// Reset tree size to 0 but keep the entry so scoring remains consistent.
if
let
Some
(
size
)
=
self
.tree_sizes
.get
(
&
worker
)
{
size
.store
(
0
,
Ordering
::
Relaxed
);
}
}
else
{
// Fully remove the worker from tree_sizes.
self
.tree_sizes
.remove
(
&
worker
);
}
}
}
fn
dump_events
(
&
self
,
worker_blocks
:
&
FxHashMap
<
WorkerWithDpRank
,
LevelIndex
>
,
)
->
Vec
<
RouterEvent
>
{
let
mut
events
=
Vec
::
new
();
let
mut
event_id
=
0u64
;
for
(
worker
,
worker_map
)
in
worker_blocks
.iter
()
{
// Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay.
let
mut
blocks
:
Vec
<
_
>
=
worker_map
.iter
()
.map
(|(
seq_hash
,
(
pos
,
local_hash
))|
(
*
pos
,
*
local_hash
,
*
seq_hash
))
.collect
();
blocks
.sort_unstable_by_key
(|(
pos
,
_
,
_
)|
*
pos
);
// Track one valid seq_hash per position for parent_hash synthesis.
// Note: The synthesized parent_hash doesn't need to be the true logical
// parent — during replay it's only used to derive `start_pos = parent.position + 1`,
// so any seq_hash at the previous position is sufficient. The PositionalIndexer
// is position-based, not tree-topology-based.
let
mut
last_at_position
:
FxHashMap
<
usize
,
ExternalSequenceBlockHash
>
=
FxHashMap
::
default
();
for
(
pos
,
local_hash
,
seq_hash
)
in
blocks
{
let
parent_hash
=
if
pos
==
0
{
None
}
else
{
match
last_at_position
.get
(
&
(
pos
-
1
))
{
Some
(
&
parent
)
=>
Some
(
parent
),
None
=>
{
tracing
::
warn!
(
worker_id
=
worker
.worker_id
.to_string
(),
dp_rank
=
worker
.dp_rank
,
position
=
pos
,
"Orphaned block at position with no parent; skipping in dump"
);
continue
;
}
}
};
events
.push
(
RouterEvent
{
worker_id
:
worker
.worker_id
,
event
:
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
seq_hash
,
tokens_hash
:
local_hash
,
mm_extra_info
:
None
,
}],
}),
dp_rank
:
worker
.dp_rank
,
},
});
event_id
+=
1
;
last_at_position
.insert
(
pos
,
seq_hash
);
}
}
events
}
}
...
...
@@ -533,11 +530,10 @@ impl PositionalIndexer {
hi
:
usize
,
early_exit
:
bool
,
)
{
for
pos
in
lo
..
hi
{
if
active
.is_empty
()
{
break
;
return
;
}
for
pos
in
lo
..
hi
{
let
Some
(
entry
)
=
self
.index
.get
(
&
(
pos
,
sequence
[
pos
]))
else
{
for
worker
in
active
.iter
()
{
scores
.scores
.insert
(
*
worker
,
pos
as
u32
);
...
...
@@ -568,6 +564,7 @@ impl PositionalIndexer {
scores
.scores
.insert
(
*
worker
,
pos
as
u32
);
}
active
.clear
();
break
;
}
}
}
...
...
@@ -626,10 +623,12 @@ impl PositionalIndexer {
scores
.scores
.insert
(
*
worker
,
1
);
}
// Populate tree_sizes
let
wb
=
self
.worker_blocks
.read
();
for
worker
in
scores
.scores
.keys
()
{
if
let
Some
(
level_index
)
=
wb
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
level_index
.read
()
.len
());
if
let
Some
(
worker_tree_size
)
=
self
.tree_sizes
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
worker_tree_size
.load
(
Ordering
::
Relaxed
));
}
}
return
scores
;
...
...
@@ -677,11 +676,11 @@ impl PositionalIndexer {
scores
.scores
.insert
(
worker
,
final_score
);
}
// Populate tree_sizes from worker_blocks
let
wb
=
self
.worker_blocks
.read
();
for
worker
in
scores
.scores
.keys
()
{
if
let
Some
(
level_index
)
=
wb
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
level_index
.read
()
.len
());
if
let
Some
(
worker_tree_size
)
=
self
.tree_sizes
.get
(
worker
)
{
scores
.tree_sizes
.insert
(
*
worker
,
worker_tree_size
.load
(
Ordering
::
Relaxed
));
}
}
...
...
tests/fault_tolerance/deploy/container/Dockerfile.local_vllm
View file @
9e5014da
...
...
@@ -182,6 +182,7 @@ RUN apt-get update -y && \
pybind11-dev \
clang \
libclang-dev \
libfontconfig-dev \
protobuf-compiler && \
rm -rf /var/lib/apt/lists/*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment