Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
3ea22fcf
Unverified
Commit
3ea22fcf
authored
Nov 12, 2025
by
Waël Boukhobza
Committed by
GitHub
Nov 12, 2025
Browse files
feat(router): max tree size based pruning (#4057)
Signed-off-by:
Wael Boukhobza
<
wawa_wael@live.fr
>
parent
a207b4be
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
431 additions
and
58 deletions
+431
-58
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+5
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+5
-0
lib/llm/src/kv_router/approx.rs
lib/llm/src/kv_router/approx.rs
+402
-58
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+19
-0
No files found.
lib/bindings/python/rust/llm/kv.rs
View file @
3ea22fcf
...
...
@@ -726,10 +726,15 @@ impl ApproxKvIndexer {
#[new]
fn
new
(
component
:
Component
,
kv_block_size
:
usize
,
ttl_secs
:
f64
)
->
PyResult
<
Self
>
{
let
ttl
=
tokio
::
time
::
Duration
::
from_secs_f64
(
ttl_secs
);
let
prune_config
=
Some
(
llm_rs
::
kv_router
::
approx
::
PruneConfig
{
max_tree_size
:
2u
size
.pow
(
14
),
// 2** 14 = 16384
prune_target_ratio
:
0.8
,
});
let
inner
=
Arc
::
new
(
llm_rs
::
kv_router
::
approx
::
ApproxKvIndexer
::
new
(
component
.inner
.drt
()
.runtime
()
.child_token
(),
kv_block_size
as
u32
,
ttl
,
prune_config
,
));
Ok
(
Self
{
inner
})
}
...
...
lib/llm/src/kv_router.rs
View file @
3ea22fcf
...
...
@@ -36,6 +36,7 @@ pub use prefill_router::PrefillRouter;
use
crate
::{
kv_router
::{
approx
::
ApproxKvIndexer
,
approx
::
PruneConfig
,
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
,
OverlapScores
,
RouterEvent
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
...
...
@@ -259,6 +260,10 @@ impl KvRouter {
cancellation_token
.clone
(),
block_size
,
Duration
::
from_secs
(
120
),
Some
(
PruneConfig
{
max_tree_size
:
2u
size
.pow
(
14
),
// 2** 14 = 16384
prune_target_ratio
:
0.8
,
}),
))
};
...
...
lib/llm/src/kv_router/approx.rs
View file @
3ea22fcf
...
...
@@ -13,13 +13,15 @@
//!
//! - The thinking behind this is that if we send a request to a worker, and shortly after get a request with a similar prefix, odds
//! are that routing to the same worker will result in a large cache hit.
//! - Another benefit is the ability to bound the size of the radix tree, which is not possible if we were trying to accurately represent
//! the state of each worker.
use
async_trait
::
async_trait
;
use
std
::
cmp
::
Reverse
;
use
std
::
collections
::{
BinaryHeap
,
HashMap
};
use
std
::
hash
::
Hash
;
use
std
::
sync
::
OnceLock
;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
use
tokio
::
sync
::{
mpsc
,
oneshot
,
watch
};
use
tokio
::
time
::{
Duration
,
Instant
};
use
tokio_util
::
sync
::
CancellationToken
;
...
...
@@ -54,45 +56,78 @@ struct RouterResult {
sequence_hashes
:
Vec
<
u64
>
,
}
#[derive(Debug,
Clone,
Copy,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord)]
struct
TimerEntry
{
/// The key of the timer.
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
#[derive(Debug,
Clone,
Copy,
Hash,
PartialEq,
Eq)]
struct
BlockEntry
{
/// The key of the block entry.
key
:
ExternalSequenceBlockHash
,
/// The worker (with dp_rank) that stored this block.
worker
:
WorkerWithDpRank
,
/// The position of this block in the sequence (0-indexed).
seq_position
:
usize
,
}
impl
PartialOrd
for
BlockEntry
{
fn
partial_cmp
(
&
self
,
other
:
&
Self
)
->
Option
<
std
::
cmp
::
Ordering
>
{
Some
(
self
.cmp
(
other
))
}
}
impl
Ord
for
BlockEntry
{
fn
cmp
(
&
self
,
other
:
&
Self
)
->
std
::
cmp
::
Ordering
{
// Break ties by sequence position (important for pruning), then by key, then by worker.
self
.seq_position
.cmp
(
&
other
.seq_position
)
.then_with
(||
self
.key
.cmp
(
&
other
.key
))
.then_with
(||
self
.worker
.cmp
(
&
other
.worker
))
}
}
#[derive(Debug,
Clone)]
pub
struct
PruneConfig
{
/// The maximum tree size before pruning is considered.
pub
max_tree_size
:
usize
,
/// The target size ratio to prune down to when max_tree_size is exceeded.
/// For example, if max_tree_size is 100 and target_size_ratio is 0.5,
/// we will prune down to 50 nodes when max_tree_size is exceeded.
pub
prune_target_ratio
:
f64
,
}
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
/// The [`
Timer
Manager::expirations`] heap is lazily updated to reflect the true expiration times in [`
Timer
Manager::timers`]
/// The [`
Prune
Manager::expirations`] heap is lazily updated to reflect the true expiration times in [`
Prune
Manager::timers`]
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
struct
Timer
Manager
<
K
:
Clone
+
Hash
+
Eq
+
Ord
>
{
struct
Prune
Manager
<
K
:
Clone
+
Hash
+
Eq
+
Ord
>
{
/// The source of truth. Maps a key to its current expiration instant.
timers
:
HashMap
<
K
,
Instant
>
,
/// A min-heap of (expiration_instant, key) used to efficiently find the
/// next expiring timer. An entry in this heap is "stale" if the instant
/// does not match the one in the `timers` map.
expirations
:
BinaryHeap
<
Reverse
<
(
Instant
,
K
)
>>
,
/// The expiration duration of the timers.
ttl
:
Duration
,
/// A max-heap of (Reverse<expiration_instant>, key) used to efficiently find the
/// next expiring timer. Reverse<Instant> makes earlier times pop first.
/// An entry in this heap is "stale" if the instant does not match the one in the `timers` map.
expirations
:
BinaryHeap
<
(
Reverse
<
Instant
>
,
K
)
>
,
/// Threshold for rebuilding the heap.
/// The heap will be rebuilt from scratch to remove stale entries.
threshold
:
usize
,
/// The expiration duration of the timers.
ttl
:
Duration
,
/// The configuration for tree-size pruning.
prune_config
:
Option
<
PruneConfig
>
,
}
impl
<
K
:
Clone
+
Hash
+
Eq
+
Ord
>
Timer
Manager
<
K
>
{
/// Creates a new, empty
Timer
Manager.
pub
fn
new
(
ttl
:
Duration
,
threshold
:
usize
)
->
Self
{
Timer
Manager
{
impl
<
K
:
Clone
+
Hash
+
Eq
+
Ord
>
Prune
Manager
<
K
>
{
/// Creates a new, empty
Prune
Manager.
pub
fn
new
(
ttl
:
Duration
,
threshold
:
usize
,
prune_config
:
Option
<
PruneConfig
>
)
->
Self
{
Prune
Manager
{
timers
:
HashMap
::
new
(),
expirations
:
BinaryHeap
::
new
(),
ttl
,
threshold
,
prune_config
,
}
}
...
...
@@ -101,7 +136,7 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
self
.expirations
=
self
.timers
.iter
()
.map
(|(
key
,
&
expiry
)|
Reverse
(
(
expiry
,
key
.clone
()))
)
.map
(|(
key
,
&
expiry
)|
(
Reverse
(
expiry
)
,
key
.clone
()))
.collect
();
}
...
...
@@ -120,7 +155,7 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
// Push the new expiration onto the heap. If the key was updated,
// this leaves a "stale" entry on the heap for the old time,
// which will be ignored when it's popped.
self
.expirations
.push
(
Reverse
(
(
expiry_time
,
key
))
)
;
self
.expirations
.push
(
(
Reverse
(
expiry_time
)
,
key
));
}
// Check if we should rebuild the heap to remove stale entries
...
...
@@ -135,14 +170,14 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
let
mut
expired_keys
=
Vec
::
new
();
let
now
=
Instant
::
now
();
while
let
Some
(
Reverse
(
(
expiry_time
,
_
))
)
=
self
.expirations
.peek
()
{
while
let
Some
(
(
Reverse
(
expiry_time
)
,
_
))
=
self
.expirations
.peek
()
{
// If the next timer in the heap is not yet expired, we can stop.
if
*
expiry_time
>
now
{
break
;
}
// The timer might be expired, so pop it from the heap.
let
Reverse
(
(
expiry_time
,
key
)
)
=
self
.expirations
.pop
()
.unwrap
();
let
(
Reverse
(
expiry_time
)
,
key
)
=
self
.expirations
.pop
()
.unwrap
();
if
self
.timers
.get
(
&
key
)
==
Some
(
&
expiry_time
)
{
// This is a valid, non-stale, expired timer.
...
...
@@ -158,7 +193,57 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
pub
fn
peek_next_expiry
(
&
self
)
->
Option
<
Instant
>
{
self
.expirations
.peek
()
.map
(|
Reverse
((
expiry_time
,
_
))|
*
expiry_time
)
.map
(|(
Reverse
(
expiry_time
),
_
)|
*
expiry_time
)
}
/// Prunes the tree if the current size is greater than the max tree size.
pub
fn
prune
(
&
mut
self
,
current_size
:
usize
)
->
Result
<
Vec
<
K
>
,
KvRouterError
>
{
let
max_tree_size
:
usize
;
let
prune_target_ratio
:
f64
;
if
let
Some
(
prune_config
)
=
&
self
.prune_config
{
max_tree_size
=
prune_config
.max_tree_size
;
prune_target_ratio
=
prune_config
.prune_target_ratio
;
}
else
{
tracing
::
error!
(
"Prune was called but prune config is None. This should never happen"
);
return
Err
(
KvRouterError
::
PruneFailed
(
"prune config is missing"
.to_string
(),
));
}
if
current_size
<=
max_tree_size
{
// Tree size within bounds, no pruning needed.
return
Ok
(
Vec
::
new
());
}
tracing
::
info!
(
"Pruning: tree size ({}) exceeded max tree size ({}), starting pruning"
,
current_size
,
max_tree_size
);
// Number of blocks that will be kept after pruning.
let
target_size
=
(
max_tree_size
as
f64
*
prune_target_ratio
)
as
usize
;
let
mut
pruned_keys
=
Vec
::
new
();
let
mut
num_pruned
=
0
;
while
num_pruned
<
current_size
.saturating_sub
(
target_size
)
{
if
let
Some
((
Reverse
(
expiry_time
),
key
))
=
self
.expirations
.pop
()
{
if
self
.timers
.get
(
&
key
)
==
Some
(
&
expiry_time
)
{
// This is a valid, non-stale timer.
self
.timers
.remove
(
&
key
);
pruned_keys
.push
(
key
);
num_pruned
+=
1
;
}
}
else
{
break
;
}
}
tracing
::
info!
(
"Pruning: pruned ({}) blocks from tree"
,
num_pruned
);
Ok
(
pruned_keys
)
}
}
...
...
@@ -180,13 +265,19 @@ pub struct ApproxKvIndexer {
}
impl
ApproxKvIndexer
{
pub
fn
new
(
token
:
CancellationToken
,
kv_block_size
:
u32
,
ttl
:
Duration
)
->
Self
{
pub
fn
new
(
token
:
CancellationToken
,
kv_block_size
:
u32
,
ttl
:
Duration
,
prune_config
:
Option
<
PruneConfig
>
,
)
->
Self
{
let
(
match_tx
,
mut
match_rx
)
=
mpsc
::
channel
::
<
MatchRequest
>
(
2048
);
let
(
route_tx
,
mut
route_rx
)
=
mpsc
::
channel
::
<
RouterResult
>
(
2048
);
let
(
remove_worker_tx
,
mut
remove_worker_rx
)
=
mpsc
::
channel
::
<
WorkerId
>
(
16
);
let
(
_
get_workers_tx
,
mut
get_workers_rx
)
=
mpsc
::
channel
::
<
super
::
indexer
::
GetWorkersRequest
>
(
16
);
let
(
dump_tx
,
mut
dump_rx
)
=
mpsc
::
channel
::
<
DumpRequest
>
(
16
);
let
(
prune_tx
,
mut
prune_rx
)
=
watch
::
channel
(
false
);
let
cancel_clone
=
token
.clone
();
let
task
=
std
::
thread
::
spawn
(
move
||
{
// create a new tokio runtime which will only perform work on a single thread
...
...
@@ -197,12 +288,13 @@ impl ApproxKvIndexer {
runtime
.block_on
(
async
move
{
let
mut
trie
=
RadixTree
::
new
();
// Use a reasonable threshold - can be made configurable if needed
let
mut
timer
_manager
:
Timer
Manager
<
Timer
Entry
>
=
Timer
Manager
::
new
(
ttl
,
50
);
// Use a reasonable threshold
for ttl
- can be made configurable if needed
let
mut
prune
_manager
:
Prune
Manager
<
Block
Entry
>
=
Prune
Manager
::
new
(
ttl
,
50
,
prune_config
.clone
()
);
let
mut
event_id
=
0
;
loop
{
// Create a future that sleeps until the next expiration time.
let
expiry_fut
=
if
let
Some
(
next_expiry
)
=
timer
_manager
.peek_next_expiry
()
{
let
expiry_fut
=
if
let
Some
(
next_expiry
)
=
prune
_manager
.peek_next_expiry
()
{
tokio
::
time
::
sleep_until
(
next_expiry
)
}
else
{
// If there are no timers, sleep forever.
...
...
@@ -245,12 +337,29 @@ impl ApproxKvIndexer {
}
);
let
_
=
trie
.apply_event
(
event
);
timer_manager
.insert
(
result
.sequence_hashes
.iter
()
.map
(|
h
|
TimerEntry
{
key
:
ExternalSequenceBlockHash
(
*
h
),
worker
:
result
.worker
,
})
.collect
());
if
trie
.apply_event
(
event
)
.is_ok
()
{
prune_manager
.insert
(
result
.sequence_hashes
.iter
()
.enumerate
()
.map
(|(
idx
,
h
)|
BlockEntry
{
key
:
ExternalSequenceBlockHash
(
*
h
),
worker
:
result
.worker
,
seq_position
:
idx
,
})
.collect
());
// Check if we need to prune due to tree size exceeding max threshold.
if
let
Some
(
prune_config
)
=
&
prune_manager
.prune_config
{
let
current_size
=
trie
.current_size
();
if
current_size
>
prune_config
.max_tree_size
{
tracing
::
info!
(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning"
,
current_size
,
prune_config
.max_tree_size
);
// Send a signal to the pruning watcher to schedule pruning.
if
let
Err
(
e
)
=
prune_tx
.send
(
true
)
{
tracing
::
error!
(
"Failed to send prune schedule signal: {:?}"
,
e
);
}
}
}
}
}
Some
(
dump_req
)
=
dump_rx
.recv
()
=>
{
...
...
@@ -263,8 +372,33 @@ impl ApproxKvIndexer {
request
.resp
.send
(
scores
)
.unwrap
();
}
Ok
(
_
)
=
prune_rx
.changed
()
=>
{
// The tree has exceeded the max tree size, so proceed with pruning.
if
let
Ok
(
pruned
)
=
prune_manager
.prune
(
trie
.current_size
())
{
pruned
.iter
()
.for_each
(|
p
|
{
event_id
+=
1
;
let
event
=
RouterEvent
::
new
(
p
.worker.worker_id
,
KvCacheEvent
{
event_id
,
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
vec!
[
p
.key
],
}),
dp_rank
:
p
.worker.dp_rank
,
}
);
let
_
=
trie
.apply_event
(
event
);
});
// Reset the pruning watcher to false to indicate that pruning is complete.
if
let
Err
(
e
)
=
prune_tx
.send
(
true
)
{
tracing
::
error!
(
"Failed to send prune completion signal: {:?}"
,
e
);
}
}
}
_
=
expiry_fut
=>
{
let
expired
=
timer
_manager
.pop_expired
();
let
expired
=
prune
_manager
.pop_expired
();
expired
.iter
()
.for_each
(|
e
|
{
event_id
+=
1
;
...
...
@@ -424,7 +558,7 @@ mod tests {
const
KV_BLOCK_SIZE
:
u32
=
4
;
impl
<
T
:
Clone
+
Hash
+
Eq
+
Ord
>
Timer
Manager
<
T
>
{
impl
<
T
:
Clone
+
Hash
+
Eq
+
Ord
>
Prune
Manager
<
T
>
{
pub
fn
get_expiry
(
&
self
,
key
:
&
T
)
->
Option
<&
Instant
>
{
self
.timers
.get
(
key
)
}
...
...
@@ -449,43 +583,43 @@ mod tests {
}
}
/// Validate basic insert / expiry behaviour of [`
Timer
Manager`].
/// Validate basic insert / expiry behaviour of [`
Prune
Manager`].
#[tokio::test]
async
fn
test_
timer
_manager_expiry
()
{
async
fn
test_
prune
_manager_expiry
()
{
const
TTL
:
Duration
=
Duration
::
from_millis
(
50
);
let
mut
t
m
:
Timer
Manager
<
u32
>
=
Timer
Manager
::
new
(
TTL
,
50
);
let
mut
p
m
:
Prune
Manager
<
u32
>
=
Prune
Manager
::
new
(
TTL
,
50
,
None
);
t
m
.insert
(
vec!
[
1
,
2
,
3
]);
assert
!
(
t
m
.get_expiry
(
&
1
)
.is_some
());
assert
!
(
t
m
.get_expiry
(
&
2
)
.is_some
());
assert
!
(
t
m
.get_expiry
(
&
3
)
.is_some
());
p
m
.insert
(
vec!
[
1
,
2
,
3
]);
assert
!
(
p
m
.get_expiry
(
&
1
)
.is_some
());
assert
!
(
p
m
.get_expiry
(
&
2
)
.is_some
());
assert
!
(
p
m
.get_expiry
(
&
3
)
.is_some
());
// Wait until after the TTL
time
::
sleep
(
TTL
+
Duration
::
from_millis
(
20
))
.await
;
let
expired
=
t
m
.pop_expired
();
let
expired
=
p
m
.pop_expired
();
assert_eq!
(
expired
.len
(),
3
);
assert
!
(
t
m
.get_expiry
(
&
1
)
.is_none
());
assert
!
(
t
m
.get_expiry
(
&
2
)
.is_none
());
assert
!
(
t
m
.get_expiry
(
&
3
)
.is_none
());
assert
!
(
p
m
.get_expiry
(
&
1
)
.is_none
());
assert
!
(
p
m
.get_expiry
(
&
2
)
.is_none
());
assert
!
(
p
m
.get_expiry
(
&
3
)
.is_none
());
}
/// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
#[tokio::test]
async
fn
test_
timer
_manager_update_resets_ttl
()
{
async
fn
test_
prune
_manager_update_resets_ttl
()
{
// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
const
TTL
:
Duration
=
Duration
::
from_millis
(
50
);
let
mut
t
m
:
Timer
Manager
<
u32
>
=
Timer
Manager
::
new
(
TTL
,
50
);
let
mut
p
m
:
Prune
Manager
<
u32
>
=
Prune
Manager
::
new
(
TTL
,
50
,
None
);
// Initial insert and capture the original expiry.
t
m
.insert
(
vec!
[
42
]);
let
first_expiry
=
*
t
m
p
m
.insert
(
vec!
[
42
]);
let
first_expiry
=
*
p
m
.get_expiry
(
&
42
)
.expect
(
"expiry missing after first insert"
);
// Wait for half of the original TTL before reinserting.
time
::
sleep
(
Duration
::
from_millis
(
25
))
.await
;
t
m
.insert
(
vec!
[
42
]);
let
second_expiry
=
*
t
m
p
m
.insert
(
vec!
[
42
]);
let
second_expiry
=
*
p
m
.get_expiry
(
&
42
)
.expect
(
"expiry missing after reinsertion"
);
...
...
@@ -494,7 +628,7 @@ mod tests {
// Wait until *after* the first expiry would have fired, but *before* the new expiry.
time
::
sleep
(
Duration
::
from_millis
(
30
))
.await
;
// 25ms already elapsed, +30ms = 55ms > first TTL
let
expired
=
t
m
.pop_expired
();
let
expired
=
p
m
.pop_expired
();
assert
!
(
expired
.is_empty
(),
"key expired prematurely despite TTL refresh"
...
...
@@ -502,7 +636,7 @@ mod tests {
// Now wait until after the second expiry should have occurred.
time
::
sleep
(
Duration
::
from_millis
(
30
))
.await
;
// Ensure we pass the refreshed TTL
let
expired_after
=
t
m
.pop_expired
();
let
expired_after
=
p
m
.pop_expired
();
assert_eq!
(
expired_after
,
vec!
[
42
]);
}
...
...
@@ -514,7 +648,7 @@ mod tests {
async
fn
test_approx_kv_indexer_basic_flow
()
{
const
TTL
:
Duration
=
Duration
::
from_millis
(
200
);
let
cancel
=
CancellationToken
::
new
();
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
);
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
None
);
let
tokens
:
Vec
<
u32
>
=
vec!
[
1
,
2
,
3
,
4
];
// Exactly one KV block
let
worker_id
:
WorkerId
=
0
;
...
...
@@ -556,7 +690,7 @@ mod tests {
async
fn
test_remove_worker
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
5
);
// Large enough to avoid expiry during test
let
cancel
=
CancellationToken
::
new
();
let
mut
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
);
let
mut
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
None
);
let
tokens
:
Vec
<
u32
>
=
vec!
[
10
,
11
,
12
,
13
];
let
worker_id
:
WorkerId
=
7
;
...
...
@@ -595,7 +729,7 @@ mod tests {
const
TTL
:
Duration
=
Duration
::
from_secs
(
5
);
// Large enough to avoid expiry during test
let
cancel
=
CancellationToken
::
new
();
let
mut
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
);
let
mut
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
None
);
let
tokens
:
Vec
<
u32
>
=
vec!
[
100
,
101
,
102
,
103
];
let
worker_0
:
WorkerId
=
30
;
...
...
@@ -653,7 +787,7 @@ mod tests {
const
TTL
:
Duration
=
Duration
::
from_secs
(
5
);
let
cancel
=
CancellationToken
::
new
();
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
);
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
None
);
// Sequence A : single block
let
seq_a
:
Vec
<
u32
>
=
vec!
[
1
,
2
,
3
,
4
];
...
...
@@ -699,7 +833,7 @@ mod tests {
const
TTL
:
Duration
=
Duration
::
from_secs
(
5
);
let
cancel
=
CancellationToken
::
new
();
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
);
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
None
);
let
tokens
:
Vec
<
u32
>
=
vec!
[
9
,
8
,
7
,
6
];
let
worker_0
:
WorkerId
=
21
;
...
...
@@ -750,4 +884,214 @@ mod tests {
Some
(
&
1
)
);
}
/// Test that pruning returns empty when tree size is within the max tree size.
#[tokio::test]
async
fn
test_prune_manager_no_prune_when_within_bounds
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
10
);
let
prune_config
=
PruneConfig
{
max_tree_size
:
100
,
prune_target_ratio
:
0.5
,
};
let
mut
pm
:
PruneManager
<
u32
>
=
PruneManager
::
new
(
TTL
,
50
,
Some
(
prune_config
));
// Insert 50 keys (well below max_tree_size of 100)
pm
.insert
((
0
..
50
)
.collect
());
// Pruning should return empty vec when size is within bounds
let
pruned
=
pm
.prune
(
50
)
.unwrap
();
assert
!
(
pruned
.is_empty
());
// All keys should still be present
for
i
in
0
..
50
{
assert
!
(
pm
.get_expiry
(
&
i
)
.is_some
());
}
}
/// Test that pruning removes the oldest entries first.
#[tokio::test]
async
fn
test_prune_manager_prune_removes_oldest_first
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
10
);
let
prune_config
=
PruneConfig
{
max_tree_size
:
10
,
prune_target_ratio
:
0.5
,
};
let
mut
pm
:
PruneManager
<
u32
>
=
PruneManager
::
new
(
TTL
,
50
,
Some
(
prune_config
));
// Insert keys one at a time with delays to ensure different timestamps
for
i
in
1
..=
15
{
pm
.insert
(
vec!
[
i
]);
time
::
sleep
(
Duration
::
from_millis
(
1
))
.await
;
}
// Total: 15 keys. Trigger pruning with current_size = 15
let
pruned
=
pm
.prune
(
15
)
.unwrap
();
// Should prune down to 5 (10 * 0.5), so 10 keys should be pruned (15 - 5)
assert_eq!
(
pruned
.len
(),
10
);
// The oldest keys should be pruned first
for
i
in
1
..=
10
{
assert
!
(
pruned
.contains
(
&
i
));
}
// The newer keys should still be present
for
i
in
11
..=
15
{
assert
!
(
pm
.get_expiry
(
&
i
)
.is_some
());
}
}
/// Test that pruning fails gracefully when config is None.
#[tokio::test]
async
fn
test_prune_manager_prune_fails_without_config
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
10
);
let
mut
pm
:
PruneManager
<
u32
>
=
PruneManager
::
new
(
TTL
,
50
,
None
);
pm
.insert
(
vec!
[
1
,
2
,
3
]);
// Pruning should fail when prune_config is None
let
result
=
pm
.prune
(
150
);
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
,
Err
(
KvRouterError
::
PruneFailed
(
_
))));
}
/// Test that BlockEntry ordering prioritizes sequence position.
#[test]
fn
test_block_entry_ordering
()
{
let
worker
=
WorkerWithDpRank
::
from_worker_id
(
0
);
let
entry1
=
BlockEntry
{
key
:
ExternalSequenceBlockHash
(
100
),
worker
,
seq_position
:
0
,
};
let
entry2
=
BlockEntry
{
key
:
ExternalSequenceBlockHash
(
50
),
worker
,
seq_position
:
1
,
};
// entry1 < entry2 because seq_position 0 < 1
assert
!
(
entry1
<
entry2
);
}
/// End-to-end test for [`ApproxKvIndexer`] with pruning
/// 0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
/// 1. Insert 5 blocks (at max_tree_size but not exceeding)
/// 2. Verify all 5 blocks are present
/// 3. Insert 6th block (exceeds threshold, triggers reactive pruning)
/// 4. Verify pruning occurred: 4 oldest blocks removed
/// 5. Verify 2 newest blocks remain
#[tokio::test]
async
fn
test_approx_indexer_e2e_pruning
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
60
);
// Long TTL to avoid expiry
let
prune_config
=
PruneConfig
{
max_tree_size
:
5
,
// Very small to trigger pruning quickly
prune_target_ratio
:
0.4
,
// target size is 5 * 0.4 = 2
};
let
cancel
=
CancellationToken
::
new
();
let
indexer
=
ApproxKvIndexer
::
new
(
cancel
.clone
(),
KV_BLOCK_SIZE
,
TTL
,
Some
(
prune_config
));
let
worker
=
WorkerWithDpRank
::
from_worker_id
(
42
);
// Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
for
i
in
0
..
5
{
let
tokens
:
Vec
<
u32
>
=
vec!
[
i
*
10
,
i
*
10
+
1
,
i
*
10
+
2
,
i
*
10
+
3
];
indexer
.process_routing_decision_for_request
(
&
tokens
,
worker
)
.await
.unwrap
();
time
::
sleep
(
Duration
::
from_millis
(
1
))
.await
;
// Ensure different timestamps
}
// Verify all 5 blocks are present (no pruning yet)
for
i
in
0
..
5
{
let
tokens
:
Vec
<
u32
>
=
vec!
[
i
*
10
,
i
*
10
+
1
,
i
*
10
+
2
,
i
*
10
+
3
];
let
scores
=
indexer
.find_matches_for_request
(
&
tokens
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.get
(
&
worker
)
.copied
(),
Some
(
1
),
"Block {} should be present before threshold is exceeded"
,
i
);
}
// Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
let
tokens
:
Vec
<
u32
>
=
vec!
[
50
,
51
,
52
,
53
];
indexer
.process_routing_decision_for_request
(
&
tokens
,
worker
)
.await
.unwrap
();
// Wait for pruning to complete
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
// After pruning, we will have exactly 2 blocks (5 * 0.4 = 2)
// The 2 newest blocks (i=4, i=5) will remain, oldest 4 blocks (i=0,1,2,3) will be pruned
// Verify that the 4 oldest blocks are pruned
for
i
in
0
..
4
{
let
tokens
:
Vec
<
u32
>
=
vec!
[
i
*
10
,
i
*
10
+
1
,
i
*
10
+
2
,
i
*
10
+
3
];
let
scores
=
indexer
.find_matches_for_request
(
&
tokens
)
.await
.unwrap
();
assert
!
(
scores
.scores
.get
(
&
worker
)
.copied
()
.unwrap_or
(
0
)
==
0
,
"Block {} should have been pruned but is still present"
,
i
);
}
// Verify the 2 newest blocks are present
for
i
in
4
..
6
{
let
tokens
:
Vec
<
u32
>
=
vec!
[
i
*
10
,
i
*
10
+
1
,
i
*
10
+
2
,
i
*
10
+
3
];
let
scores
=
indexer
.find_matches_for_request
(
&
tokens
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.get
(
&
worker
)
.copied
(),
Some
(
1
),
"Block {} should have been present but was pruned"
,
i
);
}
}
/// Test that re-inserting a key updates its position in the pruning queue.
#[tokio::test]
async
fn
test_prune_manager_prune_reinsertion_updates_position
()
{
const
TTL
:
Duration
=
Duration
::
from_secs
(
10
);
let
prune_config
=
PruneConfig
{
max_tree_size
:
5
,
prune_target_ratio
:
0.8
,
};
let
mut
pm
:
PruneManager
<
u32
>
=
PruneManager
::
new
(
TTL
,
50
,
Some
(
prune_config
));
// Insert keys
for
i
in
1
..=
10
{
pm
.insert
(
vec!
[
i
]);
time
::
sleep
(
Duration
::
from_millis
(
1
))
.await
;
}
// Re-insert key 1 (should move it to the back of the queue)
pm
.insert
(
vec!
[
1
]);
// Total: 10 unique keys. Trigger pruning: current_size = 10, target = 4, so prune 6 keys
// Order by expiry (oldest first): 2, 3, 4, 5, 6, 7, 8, 9, 10, 1 (re-inserted)
let
pruned
=
pm
.prune
(
10
)
.unwrap
();
assert_eq!
(
pruned
.len
(),
6
);
// The oldest keys (2-7) should be pruned
for
i
in
2
..=
7
{
assert
!
(
pruned
.contains
(
&
i
));
}
// The newest keys (8-10) should still be present
for
i
in
8
..=
10
{
assert
!
(
pm
.get_expiry
(
&
i
)
.is_some
());
}
// Key 1 should still be present (it was refreshed and is now near the end)
assert
!
(
pm
.get_expiry
(
&
1
)
.is_some
());
}
}
lib/llm/src/kv_router/indexer.rs
View file @
3ea22fcf
...
...
@@ -68,6 +68,9 @@ pub enum KvRouterError {
#[error(
"Indexer is dropped request"
)]
IndexerDroppedRequest
,
#[error(
"Prune operation failed: {0}"
)]
PruneFailed
(
String
),
}
/// Errors that can occur during KV Cache Event processing.
...
...
@@ -235,6 +238,8 @@ pub struct RadixTree {
lookup
:
HashMap
<
WorkerWithDpRank
,
HashMap
<
ExternalSequenceBlockHash
,
SharedRadixBlock
>>
,
/// The time buffer the radix tree should check when considering frequence of block accesses
expiration_duration
:
Option
<
Duration
>
,
/// The tree current size.
current_size
:
usize
,
}
impl
Default
for
RadixTree
{
...
...
@@ -254,6 +259,7 @@ impl RadixTree {
root
:
Rc
::
new
(
RefCell
::
new
(
RadixBlock
::
new
())),
lookup
:
HashMap
::
new
(),
expiration_duration
,
current_size
:
0
,
}
}
...
...
@@ -380,6 +386,9 @@ impl RadixTree {
.children
.insert
(
block_id
.tokens_hash
,
new_block
.clone
());
// increment the current size when creating a new block
self
.current_size
=
self
.current_size
.saturating_add
(
1
);
new_block
}
};
...
...
@@ -428,6 +437,9 @@ impl RadixTree {
if
guard
.workers
.is_empty
()
{
// if no workers are using this block, that is true for all children
guard
.children
.clear
();
// Decrement the current size when removing the last worker from a node
self
.current_size
=
self
.current_size
.saturating_sub
(
1
);
}
// remove the block from the lookup table
worker_lookup
.remove
(
&
block
);
...
...
@@ -460,6 +472,9 @@ impl RadixTree {
// If no workers are using this block, that is true for all children
if
block
.borrow
()
.workers
.is_empty
()
{
block
.borrow_mut
()
.children
.clear
();
// Decrement the current size when removing the last worker from a node
self
.current_size
=
self
.current_size
.saturating_sub
(
1
);
}
});
...
...
@@ -560,6 +575,10 @@ impl RadixTree {
events
}
pub
fn
current_size
(
&
self
)
->
usize
{
self
.current_size
}
}
/// Metrics for the KV Indexer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment