Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
0fba01c2
Unverified
Commit
0fba01c2
authored
Dec 16, 2025
by
Yan Ru Pei
Committed by
GitHub
Dec 17, 2025
Browse files
fix: make gap detection work e2e (#4993)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
54636097
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1164 additions
and
1288 deletions
+1164
-1288
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+0
-1
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+827
-1008
lib/llm/src/kv_router/publisher.rs
lib/llm/src/kv_router/publisher.rs
+27
-48
lib/llm/src/kv_router/scoring.rs
lib/llm/src/kv_router/scoring.rs
+0
-75
lib/llm/src/kv_router/subscriber.rs
lib/llm/src/kv_router/subscriber.rs
+66
-62
tests/conftest.py
tests/conftest.py
+35
-0
tests/router/common.py
tests/router/common.py
+142
-58
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+67
-36
No files found.
lib/llm/src/kv_router.rs
View file @
0fba01c2
...
@@ -31,7 +31,6 @@ pub mod protocols;
...
@@ -31,7 +31,6 @@ pub mod protocols;
pub
mod
publisher
;
pub
mod
publisher
;
pub
mod
recorder
;
pub
mod
recorder
;
pub
mod
scheduler
;
pub
mod
scheduler
;
pub
mod
scoring
;
pub
mod
sequence
;
pub
mod
sequence
;
pub
mod
subscriber
;
pub
mod
subscriber
;
pub
mod
worker_query
;
pub
mod
worker_query
;
...
...
lib/llm/src/kv_router/indexer.rs
View file @
0fba01c2
...
@@ -814,13 +814,6 @@ pub struct GetWorkersRequest {
...
@@ -814,13 +814,6 @@ pub struct GetWorkersRequest {
pub
resp
:
oneshot
::
Sender
<
Vec
<
WorkerId
>>
,
pub
resp
:
oneshot
::
Sender
<
Vec
<
WorkerId
>>
,
}
}
/// Request for the most recent event ID received from each worker.
///
/// Used for fault tolerance recovery: the answer tells the caller which events
/// still need to be requested from workers.
pub struct GetLastReceivedEventIdsRequest {
    /// One-shot reply channel that carries the `WorkerId -> last received event ID` map.
    pub resp: oneshot::Sender<HashMap<WorkerId, u64>>,
}
#[async_trait]
#[async_trait]
pub
trait
KvIndexerInterface
{
pub
trait
KvIndexerInterface
{
/// Find matches for a given sequence of `LocalBlockHash`es.
/// Find matches for a given sequence of `LocalBlockHash`es.
...
@@ -926,8 +919,6 @@ pub struct KvIndexer {
...
@@ -926,8 +919,6 @@ pub struct KvIndexer {
dump_tx
:
mpsc
::
Sender
<
DumpRequest
>
,
dump_tx
:
mpsc
::
Sender
<
DumpRequest
>
,
/// A sender for routing decision requests.
/// A sender for routing decision requests.
routing_tx
:
mpsc
::
Sender
<
RoutingDecisionRequest
>
,
routing_tx
:
mpsc
::
Sender
<
RoutingDecisionRequest
>
,
/// A sender for getting last received event IDs (for fault tolerance recovery).
last_event_ids_tx
:
mpsc
::
Sender
<
GetLastReceivedEventIdsRequest
>
,
/// The size of the KV block this indexer can handle.
/// The size of the KV block this indexer can handle.
kv_block_size
:
u32
,
kv_block_size
:
u32
,
/// Reference counter for Clone-aware Drop.
/// Reference counter for Clone-aware Drop.
...
@@ -962,8 +953,6 @@ impl KvIndexer {
...
@@ -962,8 +953,6 @@ impl KvIndexer {
let
(
dump_tx
,
dump_rx
)
=
mpsc
::
channel
::
<
DumpRequest
>
(
16
);
let
(
dump_tx
,
dump_rx
)
=
mpsc
::
channel
::
<
DumpRequest
>
(
16
);
let
(
routing_tx
,
mut
routing_rx
)
=
mpsc
::
channel
::
<
RoutingDecisionRequest
>
(
2048
);
let
(
routing_tx
,
mut
routing_rx
)
=
mpsc
::
channel
::
<
RoutingDecisionRequest
>
(
2048
);
let
(
prune_tx
,
mut
prune_rx
)
=
mpsc
::
channel
::
<
()
>
(
1
);
let
(
prune_tx
,
mut
prune_rx
)
=
mpsc
::
channel
::
<
()
>
(
1
);
let
(
last_event_ids_tx
,
mut
last_event_ids_rx
)
=
mpsc
::
channel
::
<
GetLastReceivedEventIdsRequest
>
(
16
);
let
cancel_clone
=
token
.clone
();
let
cancel_clone
=
token
.clone
();
...
@@ -989,10 +978,6 @@ impl KvIndexer {
...
@@ -989,10 +978,6 @@ impl KvIndexer {
});
});
let
mut
event_id_counter
=
0u64
;
let
mut
event_id_counter
=
0u64
;
// Track last received event ID per worker (for fault tolerance recovery)
// Only used when enable_event_tracking is true
let
mut
last_received_event_id
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
loop
{
loop
{
// Create a future that sleeps until the next expiration time
// Create a future that sleeps until the next expiration time
let
expiry_fut
=
if
let
Some
(
ref
pm
)
=
prune_manager
let
expiry_fut
=
if
let
Some
(
ref
pm
)
=
prune_manager
...
@@ -1019,10 +1004,6 @@ impl KvIndexer {
...
@@ -1019,10 +1004,6 @@ impl KvIndexer {
let
_
=
get_workers_req
.resp
.send
(
workers
);
let
_
=
get_workers_req
.resp
.send
(
workers
);
}
}
Some
(
req
)
=
last_event_ids_rx
.recv
()
=>
{
let
_
=
req
.resp
.send
(
last_received_event_id
.clone
());
}
Some
(
_
)
=
prune_rx
.recv
()
=>
{
Some
(
_
)
=
prune_rx
.recv
()
=>
{
// Tree size-based pruning triggered
// Tree size-based pruning triggered
let
Some
(
ref
mut
pm
)
=
prune_manager
else
{
continue
};
let
Some
(
ref
mut
pm
)
=
prune_manager
else
{
continue
};
...
@@ -1045,33 +1026,6 @@ impl KvIndexer {
...
@@ -1045,33 +1026,6 @@ impl KvIndexer {
}
}
Some
(
event
)
=
event_rx
.recv
()
=>
{
Some
(
event
)
=
event_rx
.recv
()
=>
{
// Track last received event ID per worker
// Check for gaps before updating the last received ID
// TODO should this trigger a recovery event?
let
last_id
=
*
last_received_event_id
.get
(
&
event
.worker_id
)
.unwrap_or
(
&
0
);
let
incoming_id
=
event
.event.event_id
;
// Detect gap: if incoming ID is more than 1 greater than last received
if
incoming_id
>
last_id
+
1
&&
last_id
>
0
{
let
gap_start
=
last_id
+
1
;
let
gap_end
=
incoming_id
-
1
;
tracing
::
warn!
(
worker_id
=
event
.worker_id
,
gap_start
,
gap_end
,
gap_size
=
gap_end
-
gap_start
+
1
,
"Event ID gap detected! Missed events [{}, {}].
\
If this is a global KvIndexer, within a KvRouter context,
consider calling KvRouter::query_worker_local_kv() to potentially recover worker-stored events."
,
gap_start
,
gap_end
,
);
}
// Update last received event ID (use max to handle out-of-order events)
let
entry
=
last_received_event_id
.entry
(
event
.worker_id
)
.or_insert
(
0
);
*
entry
=
(
*
entry
)
.max
(
event
.event.event_id
);
let
event_type
=
KvIndexerMetrics
::
get_event_type
(
&
event
.event.data
);
let
event_type
=
KvIndexerMetrics
::
get_event_type
(
&
event
.event.data
);
let
result
=
trie
.apply_event
(
event
.clone
());
let
result
=
trie
.apply_event
(
event
.clone
());
let
result_is_ok
=
result
.is_ok
();
let
result_is_ok
=
result
.is_ok
();
...
@@ -1200,7 +1154,6 @@ impl KvIndexer {
...
@@ -1200,7 +1154,6 @@ impl KvIndexer {
get_workers_tx
,
get_workers_tx
,
dump_tx
,
dump_tx
,
routing_tx
,
routing_tx
,
last_event_ids_tx
,
kv_block_size
,
kv_block_size
,
_
ref_count
:
Arc
::
new
(()),
_
ref_count
:
Arc
::
new
(()),
}
}
...
@@ -1253,48 +1206,6 @@ impl KvIndexer {
...
@@ -1253,48 +1206,6 @@ impl KvIndexer {
pub
fn
get_workers_sender
(
&
self
)
->
mpsc
::
Sender
<
GetWorkersRequest
>
{
pub
fn
get_workers_sender
(
&
self
)
->
mpsc
::
Sender
<
GetWorkersRequest
>
{
self
.get_workers_tx
.clone
()
self
.get_workers_tx
.clone
()
}
}
/// Hand out a clone of the sender used for last-received-event-ID requests.
///
/// ### Returns
///
/// A `mpsc::Sender` for `GetLastReceivedEventIdsRequest`s.
pub fn last_event_ids_sender(&self) -> mpsc::Sender<GetLastReceivedEventIdsRequest> {
    let sender = &self.last_event_ids_tx;
    sender.clone()
}
/// Get the last received event ID for each worker.
///
/// This method is used for **fault tolerance recovery** when the router needs to
/// catch up on missed events after a disconnect. By tracking the last event ID
/// received from each worker, the router can query workers for events starting
/// from `last_id + 1` to recover missed state.
///
/// **Note**: This method is intdned for the global `KvIndexer` used by routers,
/// not on `LocalKvIndexer` (worker-side) or `KvIndexerSharded`.
///
/// ### Returns
///
/// A `HashMap` mapping worker IDs to their last received event ID.
///
pub
async
fn
get_last_received_event_ids
(
&
self
,
)
->
Result
<
HashMap
<
WorkerId
,
u64
>
,
KvRouterError
>
{
let
(
resp_tx
,
resp_rx
)
=
oneshot
::
channel
();
let
req
=
GetLastReceivedEventIdsRequest
{
resp
:
resp_tx
};
if
let
Err
(
e
)
=
self
.last_event_ids_tx
.send
(
req
)
.await
{
tracing
::
error!
(
"Failed to send last event IDs request: {:?}; the indexer maybe offline"
,
e
);
return
Err
(
KvRouterError
::
IndexerOffline
);
}
resp_rx
.await
.map_err
(|
_
|
KvRouterError
::
IndexerDroppedRequest
)
}
}
}
#[async_trait]
#[async_trait]
...
@@ -1571,7 +1482,7 @@ impl LocalKvIndexer {
...
@@ -1571,7 +1482,7 @@ impl LocalKvIndexer {
"Non-consecutive KV event id; buffer may have gaps"
"Non-consecutive KV event id; buffer may have gaps"
);
);
}
}
tracing
::
info
!
(
tracing
::
debug
!
(
"Recorded event {:?} in buffer, now size is {}"
,
"Recorded event {:?} in buffer, now size is {}"
,
event
,
event
,
buffer
.len
()
buffer
.len
()
...
@@ -1640,350 +1551,104 @@ impl LocalKvIndexer {
...
@@ -1640,350 +1551,104 @@ impl LocalKvIndexer {
}
}
}
}
#[cfg(test)]
// Implement KvIndexerInterface by delegating to the underlying indexer
mod
local_kv_indexer_tests
{
#[async_trait]
use
super
::
*
;
impl
KvIndexerInterface
for
LocalKvIndexer
{
async
fn
find_matches
(
fn
make_indexer_with_events
(
ids
:
&
[
u64
])
->
LocalKvIndexer
{
&
self
,
let
indexer
=
LocalKvIndexer
::
new
(
sequence
:
Vec
<
LocalBlockHash
>
,
CancellationToken
::
new
(),
)
->
Result
<
OverlapScores
,
KvRouterError
>
{
4
,
self
.indexer
.find_matches
(
sequence
)
.await
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
32
,
);
{
let
mut
buffer
=
indexer
.event_buffer
.lock
()
.unwrap
();
for
&
id
in
ids
{
buffer
.push_back
(
RouterEvent
::
new
(
0
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Cleared
,
dp_rank
:
0
,
},
));
}
}
indexer
}
}
#[tokio::test]
async
fn
find_matches_for_request
(
async
fn
returns_slice_within_range
()
{
&
self
,
let
indexer
=
make_indexer_with_events
(
&
[
1
,
2
,
3
,
4
,
5
]);
tokens
:
&
[
u32
],
)
->
Result
<
OverlapScores
,
KvRouterError
>
{
// Helper to extract events from response
self
.indexer
.find_matches_for_request
(
tokens
)
.await
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
(
e
)
=>
e
,
_
=>
panic!
(
"Unexpected response type"
),
}
}
};
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
// Test get_events_in_id_range (buffer queries)
// Range is [start, end] inclusive
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
4
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
]);
// inclusive range [2, 4]
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
6
))
.await
;
async
fn
apply_event
(
&
mut
self
,
event
:
RouterEvent
)
{
let
ids
=
get_ids
(
extract_events
(
result
));
// Use the buffering version
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
,
5
]);
// clamp end to buffer max
let
_
=
self
.apply_event_with_buffer
(
event
)
.await
;
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let
result
=
indexer
.get_events_in_id_range
(
Some
(
0
),
Some
(
4
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
let
result
=
indexer
.get_events_in_id_range
(
Some
(
3
),
Some
(
3
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
3
]);
// single element when start == end
// Invalid range: end < start
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
2
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
}
}
#[tokio::test]
async
fn
remove_worker
(
&
mut
self
,
worker
:
WorkerId
)
{
async
fn
test_get_events_in_id_range_all_cases
()
{
let
_
=
self
.indexer
.remove_worker_sender
()
.send
(
worker
)
.await
;
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
};
}
// Create indexer with small buffer (5 events max)
fn
shutdown
(
&
mut
self
)
{
// This way older events will only be in the tree, not the buffer
// Note: Since indexer is Arc<KvIndexer>, we can't call mutable methods directly.
let
indexer
=
LocalKvIndexer
::
new
(
// The indexer will be shut down when the CancellationToken is cancelled
CancellationToken
::
new
(),
// or when the last Arc reference is dropped.
4
,
// block_size
}
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
// max_buffer_size - only keeps 5 most recent events
);
// Helper to create a test event
async
fn
dump_events
(
&
self
)
->
Result
<
Vec
<
RouterEvent
>
,
KvRouterError
>
{
let
make_event
=
|
id
:
u64
|
{
self
.indexer
.dump_events
()
.await
RouterEvent
::
new
(
}
0
,
// worker_id
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
}],
}),
dp_rank
:
0
,
},
)
};
// Add 10 events (IDs 5-14)
async
fn
process_routing_decision
(
// Buffer will only keep the last 5: events 10-14
&
self
,
// Tree will have all blocks
worker
:
WorkerWithDpRank
,
for
id
in
5
..
15
{
local_hashes
:
Vec
<
LocalBlockHash
>
,
indexer
sequence_hashes
:
Vec
<
SequenceHash
>
,
.apply_event_with_buffer
(
make_event
(
id
))
)
->
Result
<
(),
KvRouterError
>
{
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self
.indexer
.process_routing_decision
(
worker
,
local_hashes
,
sequence_hashes
)
.await
.await
.unwrap
();
}
}
// Wait for events to be processed by the tree
async
fn
process_routing_decision_for_request
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
100
))
.await
;
&
self
,
tokens
:
&
[
u32
],
// Helper to extract events from response
worker
:
WorkerWithDpRank
,
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
)
->
Result
<
(),
KvRouterError
>
{
match
resp
{
// TODO I guess the local kvindexers have little use for this method?
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
// Keeping it here now to implement the trait fully
WorkerKvQueryResponse
::
TreeDump
(
e
)
=>
e
,
self
.indexer
_
=>
panic!
(
"Unexpected response type: {:?}"
,
resp
),
.process_routing_decision_for_request
(
tokens
,
worker
)
.await
}
}
};
}
// Helper to extract event IDs from result
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
// Verify buffer state: should have events 10-14 (last 5)
let
buffer_events
=
indexer
.get_all_events_in_buffer
();
assert_eq!
(
get_ids
(
buffer_events
),
vec!
[
10
,
11
,
12
,
13
,
14
],
"Buffer should have events 10-14"
);
// ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
#[derive(Debug,
Clone)]
// Range is [start, end] inclusive
pub
struct
ShardedMatchRequest
{
sequence
:
Vec
<
LocalBlockHash
>
,
early_exit
:
bool
,
resp
:
mpsc
::
Sender
<
OverlapScores
>
,
}
// Test: start_id within buffer, no end
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
None
)
.await
;
///
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
/// ## Sharding Strategy
assert_eq!
(
/// - Each worker is **permanently assigned** to a single shard on first event
get_ids
(
extract_events
(
result
)),
/// - All KV blocks from a worker exist only in that worker's assigned shard
vec!
[
11
,
12
,
13
,
14
],
/// - New workers are assigned to the shard with the fewest workers (load balancing)
"start_id=11 (in buffer) should return [11, 14]"
///
);
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own thread with a single-threaded runtime
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub
struct
KvIndexerSharded
{
/// A `CancellationToken` for managing shutdown.
cancel
:
CancellationToken
,
/// The size of the KV block this indexer can handle.
kv_block_size
:
u32
,
worker_assignments
:
HashMap
<
WorkerId
,
usize
>
,
worker_counts
:
Vec
<
usize
>
,
// Test: start_id at buffer boundary
event_tx
:
Vec
<
mpsc
::
Sender
<
RouterEvent
>>
,
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
None
)
.await
;
request_broadcast_tx
:
broadcast
::
Sender
<
ShardedMatchRequest
>
,
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
remove_worker_tx
:
Vec
<
mpsc
::
Sender
<
WorkerId
>>
,
assert_eq!
(
dump_tx
:
Vec
<
mpsc
::
Sender
<
DumpRequest
>>
,
get_ids
(
extract_events
(
result
)),
routing_tx
:
Vec
<
mpsc
::
Sender
<
RoutingDecisionRequest
>>
,
vec!
[
10
,
11
,
12
,
13
,
14
],
tasks
:
Vec
<
JoinHandle
<
()
>>
,
"start_id=10 (buffer start) should return [10, 14]"
}
);
// Test: both start and end within buffer (inclusive)
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
Some
(
13
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
11
,
12
,
13
],
"range [11, 13] inclusive should return 3 events"
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
Some
(
14
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
10
,
11
,
12
,
13
,
14
],
"range [10, 14] should return all buffer events"
);
// ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
// Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
// that we get events back (the IDs won't match original IDs)
// Test: (None, None) dumps entire tree
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"(None, None) should dump entire tree (10 events)"
);
// Test: (None, Some(_)) dumps entire tree
let
result
=
indexer
.get_events_in_id_range
(
None
,
Some
(
8
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
);
// Test: start_id before buffer triggers tree dump
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"start_id=7 (before buffer) should dump entire tree"
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
12
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"range [5, 12] extending before buffer should dump entire tree"
);
// ========== EDGE CASES ==========
// Single element when start == end (inclusive range)
let
result
=
indexer
.get_events_in_id_range
(
Some
(
12
),
Some
(
12
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
12
],
"start == end should return single event"
);
// InvalidRange when start > end
let
result
=
indexer
.get_events_in_id_range
(
Some
(
15
),
Some
(
10
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}),
"start > end should return InvalidRange"
);
// TooNew when start_id is beyond buffer
let
result
=
indexer
.get_events_in_id_range
(
Some
(
100
),
Some
(
200
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TooNew
{
..
}),
"start_id beyond buffer should return TooNew"
);
// Request with end beyond buffer but valid start -> buffer returns what it has
let
result
=
indexer
.get_events_in_id_range
(
Some
(
12
),
Some
(
100
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
12
,
13
,
14
],
"range with end beyond buffer should return available buffer events"
);
}
}
// `LocalKvIndexer` satisfies `KvIndexerInterface` by forwarding to the indexer it wraps;
// only `apply_event` takes a different route (through the event buffer).
#[async_trait]
impl KvIndexerInterface for LocalKvIndexer {
    /// Look up matches for a sequence of local block hashes via the wrapped indexer.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        self.indexer.find_matches(sequence).await
    }

    /// Look up matches for a raw token slice via the wrapped indexer.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        self.indexer.find_matches_for_request(tokens).await
    }

    /// Apply an event through the buffering path.
    ///
    /// The outcome of `apply_event_with_buffer` is deliberately discarded because
    /// this trait method has no return value.
    async fn apply_event(&mut self, event: RouterEvent) {
        let _ = self.apply_event_with_buffer(event).await;
    }

    /// Ask the wrapped indexer to remove `worker`; a failed send is ignored.
    async fn remove_worker(&mut self, worker: WorkerId) {
        let sender = self.indexer.remove_worker_sender();
        let _ = sender.send(worker).await;
    }

    /// Intentionally a no-op.
    fn shutdown(&mut self) {
        // The indexer is held as Arc<KvIndexer>, so its mutable methods cannot be
        // called directly from here. It shuts down when the CancellationToken is
        // cancelled or when the last Arc reference is dropped.
    }

    /// Dump the events held by the wrapped indexer.
    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }

    /// Forward a routing decision (hashes already computed) to the wrapped indexer.
    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        // TODO: local KV indexers probably have little use for this method;
        // it is kept so that the trait is implemented in full.
        self.indexer
            .process_routing_decision(worker, local_hashes, sequence_hashes)
            .await
    }

    /// Forward a routing decision expressed as raw tokens to the wrapped indexer.
    async fn process_routing_decision_for_request(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        // TODO: local KV indexers probably have little use for this method;
        // it is kept so that the trait is implemented in full.
        self.indexer
            .process_routing_decision_for_request(tokens, worker)
            .await
    }
}
/// A match request as carried on the broadcast channel of `KvIndexerSharded`.
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
    /// Local block hashes to match.
    sequence: Vec<LocalBlockHash>,
    /// NOTE(review): presumably lets a shard stop matching early; confirm against the shard loop.
    early_exit: bool,
    /// Channel on which `OverlapScores` are sent back to the requester.
    resp: mpsc::Sender<OverlapScores>,
}
/// A KV indexer whose RadixTree is partitioned across several independent shards.
///
/// ## Sharding Strategy
/// - A worker is **permanently assigned** to one shard when its first event arrives
/// - Every KV block of a worker lives only in that worker's shard
/// - A new worker goes to the shard that currently has the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: sent straight to the shard the worker is assigned to
/// - **Match requests**: broadcast to every shard (scatter-gather pattern)
/// - **Threading**: each shard runs in its own thread with a single-threaded runtime
///
/// With this layout writes need no cross-shard synchronization, while processing
/// can proceed in parallel and scale better.
pub struct KvIndexerSharded {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
    /// The size of the KV block this indexer can handle.
    kv_block_size: u32,
    // NOTE(review): presumably maps each worker to the index of its shard; confirm.
    worker_assignments: HashMap<WorkerId, usize>,
    // NOTE(review): presumably the number of workers per shard, indexed by shard; confirm.
    worker_counts: Vec<usize>,
    // NOTE(review): the sender vectors below presumably hold one entry per shard; confirm.
    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    dump_tx: Vec<mpsc::Sender<DumpRequest>>,
    routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
    tasks: Vec<JoinHandle<()>>,
}
impl
KvIndexerSharded
{
impl
KvIndexerSharded
{
/// Create a new `KvIndexerSharded`.
/// Create a new `KvIndexerSharded`.
...
@@ -2429,8 +2094,8 @@ impl Drop for KvIndexerSharded {
...
@@ -2429,8 +2094,8 @@ impl Drop for KvIndexerSharded {
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
};
use
rstest
::
rstest
;
use
rstest
::
rstest
;
use
rstest_reuse
::{
self
,
*
};
use
rstest_reuse
::{
self
,
*
};
use
tokio
::
time
;
use
tokio
::
time
;
...
@@ -3640,46 +3305,245 @@ mod tests {
...
@@ -3640,46 +3305,245 @@ mod tests {
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_1
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
)));
assert
!
(
result
.contains_key
(
&
WorkerWithDpRank
::
from_worker_id
(
worker_2
)));
}
}
}
#[cfg(test)]
// LocalKvIndexer tests
mod
tests_local_indexer
{
fn
make_indexer_with_events
(
ids
:
&
[
u64
])
->
LocalKvIndexer
{
use
super
::
*
;
let
indexer
=
LocalKvIndexer
::
new
(
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
};
CancellationToken
::
new
(),
use
tokio
::
time
;
4
,
use
tokio_util
::
sync
::
CancellationToken
;
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
32
,
);
{
let
mut
buffer
=
indexer
.event_buffer
.lock
()
.unwrap
();
for
&
id
in
ids
{
buffer
.push_back
(
RouterEvent
::
new
(
0
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Cleared
,
dp_rank
:
0
,
},
));
}
}
indexer
}
fn
setup
()
{
#[tokio::test]
dynamo_runtime
::
logging
::
init
();
async
fn
returns_slice_within_range
()
{
let
indexer
=
make_indexer_with_events
(
&
[
1
,
2
,
3
,
4
,
5
]);
// Helper to extract events from response
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
(
e
)
=>
e
,
_
=>
panic!
(
"Unexpected response type"
),
}
}
};
fn
make_blocks
(
hashes
:
Vec
<
u64
>
)
->
Vec
<
KvCacheStoredBlockData
>
{
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
hashes
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
.iter
()
};
.map
(|
i
|
KvCacheStoredBlockData
{
tokens_hash
:
LocalBlockHash
(
*
i
),
// Test get_events_in_id_range (buffer queries)
block_hash
:
ExternalSequenceBlockHash
(
*
i
*
100
),
// Range is [start, end] inclusive
})
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
4
))
.await
;
.collect
()
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
]);
// inclusive range [2, 4]
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
6
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
,
5
]);
// clamp end to buffer max
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let
result
=
indexer
.get_events_in_id_range
(
Some
(
0
),
Some
(
4
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
let
result
=
indexer
.get_events_in_id_range
(
Some
(
3
),
Some
(
3
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
3
]);
// single element when start == end
// Invalid range: end < start
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
2
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
}
}
fn
create_store_event
(
#[tokio::test]
worker_id
:
WorkerId
,
async
fn
test_get_events_in_id_range_all_cases
()
{
event_id
:
u64
,
// Create indexer with small buffer (5 events max)
hashes
:
Vec
<
u64
>
,
// This way older events will only be in the tree, not the buffer
parent
:
Option
<
ExternalSequenceBlockHash
>
,
let
indexer
=
LocalKvIndexer
::
new
(
)
->
RouterEvent
{
CancellationToken
::
new
(),
RouterEvent
{
4
,
// block_size
worker_id
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
event
:
KvCacheEvent
{
5
,
// max_buffer_size - only keeps 5 most recent events
event_id
,
);
// Helper to create a test event
let
make_event
=
|
id
:
u64
|
{
RouterEvent
::
new
(
0
,
// worker_id
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
parent
,
parent_hash
:
None
,
blocks
:
make_blocks
(
hashes
),
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
}],
}),
}),
dp_rank
:
0
,
dp_rank
:
0
,
},
},
)
};
// Add 10 events (IDs 5-14)
// Buffer will only keep the last 5: events 10-14
// Tree will have all blocks
for
id
in
5
..
15
{
indexer
.apply_event_with_buffer
(
make_event
(
id
))
.await
.unwrap
();
}
}
// Wait for events to be processed by the tree
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
100
))
.await
;
// Helper to extract events from response
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
(
e
)
=>
e
,
_
=>
panic!
(
"Unexpected response type: {:?}"
,
resp
),
}
};
// Helper to extract event IDs from result
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
// Verify buffer state: should have events 10-14 (last 5)
let
buffer_events
=
indexer
.get_all_events_in_buffer
();
assert_eq!
(
get_ids
(
buffer_events
),
vec!
[
10
,
11
,
12
,
13
,
14
],
"Buffer should have events 10-14"
);
// ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
// Range is [start, end] inclusive
// Test: start_id within buffer, no end
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
11
,
12
,
13
,
14
],
"start_id=11 (in buffer) should return [11, 14]"
);
// Test: start_id at buffer boundary
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
10
,
11
,
12
,
13
,
14
],
"start_id=10 (buffer start) should return [10, 14]"
);
// Test: both start and end within buffer (inclusive)
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
Some
(
13
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
11
,
12
,
13
],
"range [11, 13] inclusive should return 3 events"
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
Some
(
14
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
10
,
11
,
12
,
13
,
14
],
"range [10, 14] should return all buffer events"
);
// ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
// Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
// that we get events back (the IDs won't match original IDs)
// Test: (None, None) dumps entire tree
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"(None, None) should dump entire tree (10 events)"
);
// Test: (None, Some(_)) dumps entire tree
let
result
=
indexer
.get_events_in_id_range
(
None
,
Some
(
8
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
);
// Test: start_id before buffer triggers tree dump
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"start_id=7 (before buffer) should dump entire tree"
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
12
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
(
_
)));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
,
"range [5, 12] extending before buffer should dump entire tree"
);
// ========== EDGE CASES ==========
// Single element when start == end (inclusive range)
let
result
=
indexer
.get_events_in_id_range
(
Some
(
12
),
Some
(
12
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
12
],
"start == end should return single event"
);
// InvalidRange when start > end
let
result
=
indexer
.get_events_in_id_range
(
Some
(
15
),
Some
(
10
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}),
"start > end should return InvalidRange"
);
// TooNew when start_id is beyond buffer
let
result
=
indexer
.get_events_in_id_range
(
Some
(
100
),
Some
(
200
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TooNew
{
..
}),
"start_id beyond buffer should return TooNew"
);
// Request with end beyond buffer but valid start -> buffer returns what it has
let
result
=
indexer
.get_events_in_id_range
(
Some
(
12
),
Some
(
100
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
Events
(
_
)));
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
12
,
13
,
14
],
"range with end beyond buffer should return available buffer events"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -3752,49 +3616,4 @@ mod tests_local_indexer {
...
@@ -3752,49 +3616,4 @@ mod tests_local_indexer {
_
=>
panic!
(
"Expected Stored event"
),
_
=>
panic!
(
"Expected Stored event"
),
}
}
}
}
/// Feeds store events for two workers into a `KvIndexer` and checks the
/// per-worker map returned by `get_last_received_event_ids`.
///
/// Worker A sends the contiguous ids 1, 2, 3; worker B sends id 1 followed by
/// id 5, skipping 2..=4. The assertions expect the reported id for each worker
/// to be the last one sent, regardless of the skipped ids.
#[tokio::test]
async fn test_gap_detection_per_worker() {
    setup();

    let cancel = CancellationToken::new();
    let indexer = KvIndexer::new(
        cancel.clone(),
        4,
        Arc::new(KvIndexerMetrics::new_unregistered()),
    );
    let sender = indexer.event_sender();

    let worker_a: WorkerId = 100;
    let worker_b: WorkerId = 200;

    // Worker A: events 1, 2, 3 (no gap); each event stores its own id as the block hash.
    for event_id in 1..=3 {
        sender
            .send(create_store_event(worker_a, event_id, vec![event_id], None))
            .await
            .unwrap();
    }

    // Worker B: event 1, then event 5 (ids 2, 3, 4 are never sent).
    for (event_id, block_hash) in [(1, 10), (5, 50)] {
        sender
            .send(create_store_event(worker_b, event_id, vec![block_hash], None))
            .await
            .unwrap();
    }

    // Give the indexer time to process the queued events before querying it.
    time::sleep(Duration::from_millis(20)).await;

    // Each worker is tracked independently.
    let last_seen = indexer.get_last_received_event_ids().await.unwrap();
    assert_eq!(
        last_seen.get(&worker_a),
        Some(&3),
        "Worker A should have last_id = 3 (no gap)"
    );
    assert_eq!(
        last_seen.get(&worker_b),
        Some(&5),
        "Worker B should have last_id = 5 (despite gap)"
    );

    // Cleanup
    cancel.cancel();
}
}
}
lib/llm/src/kv_router/publisher.rs
View file @
0fba01c2
...
@@ -36,6 +36,12 @@ use crate::kv_router::{
...
@@ -36,6 +36,12 @@ use crate::kv_router::{
};
};
use
dynamo_runtime
::
config
::
environment_names
::
nats
as
env_nats
;
use
dynamo_runtime
::
config
::
environment_names
::
nats
as
env_nats
;
// Error handling configuration for ZMQ operations
const
INITIAL_BACKOFF_MS
:
u64
=
10
;
const
MAX_BACKOFF_MS
:
u64
=
5000
;
const
MAX_CONSECUTIVE_ERRORS
:
u32
=
10
;
const
MAX_BACKOFF_EXPONENT
:
u32
=
8
;
// Cap at 2^8 = 256x multiplier to prevent overflow
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
...
@@ -125,15 +131,14 @@ impl KvEventPublisher {
...
@@ -125,15 +131,14 @@ impl KvEventPublisher {
// Infer worker_id from component's connection
// Infer worker_id from component's connection
let
worker_id
=
component
.drt
()
.connection_id
();
let
worker_id
=
component
.drt
()
.connection_id
();
let
component_name
=
component
.name
();
tracing
::
info!
(
tracing
::
info!
(
worker_id
,
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
component
=
component
.name
(),
"Initializing KvEventPublisher for worker {worker_id} in component {component}"
);
);
if
enable_local_indexer
{
if
enable_local_indexer
{
tracing
::
info!
(
tracing
::
info!
(
"LocalKvIndexer enabled for worker {worker_id} in component {component}"
"LocalKvIndexer enabled for worker {worker_id} in component {component
_name
}"
);
);
}
}
...
@@ -321,27 +326,25 @@ async fn start_worker_kv_query_service(
...
@@ -321,27 +326,25 @@ async fn start_worker_kv_query_service(
let
mut
subscriber
=
match
component
.subscribe
(
&
subject
)
.await
{
let
mut
subscriber
=
match
component
.subscribe
(
&
subject
)
.await
{
Ok
(
sub
)
=>
sub
,
Ok
(
sub
)
=>
sub
,
Err
(
e
)
=>
{
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to subscribe to {}: {}"
,
subject
,
e
);
tracing
::
error!
(
return
;
// No ? because function doesn't return Result
"Query service failed to subscribe for worker {worker_id} on subject {subject}: {e}"
);
return
;
}
}
};
};
tracing
::
debug!
(
tracing
::
info!
(
"Query service listening on NATS for worker {worker_id} on subject {subject}"
);
"Query service on worker {} listening on NATS subject: {}"
,
worker_id
,
subject
);
// Receive query request from router, retrieve event(s) from LocalKvIndexer, return response
// Receive query request from router, retrieve event(s) from LocalKvIndexer, return response
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
info!
(
"
Router-Worker communication channel
received cancellation signal"
);
tracing
::
info!
(
"
Query service
received cancellation signal
for worker {worker_id}
"
);
break
;
break
;
}
}
msg
=
subscriber
.next
()
=>
{
msg
=
subscriber
.next
()
=>
{
let
Some
(
msg
)
=
msg
else
{
let
Some
(
msg
)
=
msg
else
{
tracing
::
debug!
(
"Router-Worker stream ended.
"
);
tracing
::
warn!
(
"Query service NATS stream ended for worker {worker_id}
"
);
break
;
break
;
};
};
...
@@ -349,12 +352,12 @@ async fn start_worker_kv_query_service(
...
@@ -349,12 +352,12 @@ async fn start_worker_kv_query_service(
let
request
:
WorkerKvQueryRequest
=
match
serde_json
::
from_slice
(
&
msg
.payload
)
{
let
request
:
WorkerKvQueryRequest
=
match
serde_json
::
from_slice
(
&
msg
.payload
)
{
Ok
(
request
)
=>
request
,
Ok
(
request
)
=>
request
,
Err
(
e
)
=>
{
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to deserialize WorkerKvQueryRequest: {}"
,
e
);
tracing
::
error!
(
"Failed to deserialize WorkerKvQueryRequest
for worker {worker_id}
: {
e
}"
);
continue
;
continue
;
}
}
};
};
tracing
::
debug!
(
"Received
WorkerKvQ
uery
R
equest
: {:?}"
,
request
);
tracing
::
debug!
(
"Received
q
uery
r
equest
for worker {worker_id}: {
request
:?}"
);
// Query events based on optional start/end ids
// Query events based on optional start/end ids
let
response
=
local_indexer
let
response
=
local_indexer
...
@@ -366,7 +369,7 @@ async fn start_worker_kv_query_service(
...
@@ -366,7 +369,7 @@ async fn start_worker_kv_query_service(
let
payload
=
match
serde_json
::
to_vec
(
&
response
)
{
let
payload
=
match
serde_json
::
to_vec
(
&
response
)
{
Ok
(
p
)
=>
p
,
Ok
(
p
)
=>
p
,
Err
(
e
)
=>
{
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to serialize response: {}"
,
e
);
tracing
::
error!
(
"Failed to serialize response
for worker {worker_id}
: {
e
}"
);
continue
;
continue
;
}
}
};
};
...
@@ -377,22 +380,14 @@ async fn start_worker_kv_query_service(
...
@@ -377,22 +380,14 @@ async fn start_worker_kv_query_service(
.kv_router_nats_publish
(
reply_subject
.to_string
(),
payload
.into
())
.kv_router_nats_publish
(
reply_subject
.to_string
(),
payload
.into
())
.await
.await
{
{
tracing
::
error!
(
"Failed to send reply: {}"
,
e
);
tracing
::
error!
(
"Failed to send reply
for worker {worker_id}
: {
e
}"
);
}
}
}
}
}
}
}
}
}
}
}
}
// Error handling configuration for ZMQ operations
const
INITIAL_BACKOFF_MS
:
u64
=
10
;
const
MAX_BACKOFF_MS
:
u64
=
5000
;
const
MAX_CONSECUTIVE_ERRORS
:
u32
=
10
;
const
MAX_BACKOFF_EXPONENT
:
u32
=
8
;
// Cap at 2^8 = 256x multiplier to prevent overflow
/// Calculate exponential backoff duration based on consecutive error count
/// Calculate exponential backoff duration based on consecutive error count
fn
calculate_backoff_ms
(
consecutive_errors
:
u32
)
->
u64
{
fn
calculate_backoff_ms
(
consecutive_errors
:
u32
)
->
u64
{
std
::
cmp
::
min
(
std
::
cmp
::
min
(
...
@@ -481,7 +476,7 @@ pub async fn start_zmq_listener(
...
@@ -481,7 +476,7 @@ pub async fn start_zmq_listener(
let
mut
frames
:
Vec
<
Vec
<
u8
>>
=
msg
.into_vec
()
.into_iter
()
.map
(|
frame
|
frame
.to_vec
())
.collect
();
let
mut
frames
:
Vec
<
Vec
<
u8
>>
=
msg
.into_vec
()
.into_iter
()
.map
(|
frame
|
frame
.to_vec
())
.collect
();
if
frames
.len
()
!=
3
{
if
frames
.len
()
!=
3
{
tracing
::
warn!
(
expected
=
3
,
actual
=%
frames
.len
(),
"Received unexpected ZMQ frame count
"
);
tracing
::
warn!
(
"Received unexpected ZMQ frame count
: expected 3, actual {}"
,
frames
.len
()
);
continue
;
continue
;
}
}
...
@@ -490,7 +485,7 @@ pub async fn start_zmq_listener(
...
@@ -490,7 +485,7 @@ pub async fn start_zmq_listener(
let
seq_bytes
=
frames
.pop
()
.unwrap
();
let
seq_bytes
=
frames
.pop
()
.unwrap
();
if
seq_bytes
.len
()
!=
8
{
if
seq_bytes
.len
()
!=
8
{
tracing
::
warn!
(
expected
=
8
,
actual
=%
seq_bytes
.len
(),
"Invalid sequence number byte length"
);
tracing
::
warn!
(
"Invalid sequence number byte length: expected 8, actual {}"
,
seq_bytes
.len
()
);
continue
;
continue
;
}
}
...
@@ -500,7 +495,7 @@ pub async fn start_zmq_listener(
...
@@ -500,7 +495,7 @@ pub async fn start_zmq_listener(
let
batch_result
=
rmps
::
from_slice
::
<
KvEventBatch
>
(
&
payload
);
let
batch_result
=
rmps
::
from_slice
::
<
KvEventBatch
>
(
&
payload
);
let
Ok
(
batch
)
=
batch_result
else
{
let
Ok
(
batch
)
=
batch_result
else
{
let
e
=
batch_result
.unwrap_err
();
let
e
=
batch_result
.unwrap_err
();
tracing
::
warn!
(
error
=%
e
,
"Failed to decode KVEventBatch msgpack"
);
tracing
::
warn!
(
"Failed to decode KVEventBatch msgpack
: {e}
"
);
continue
;
continue
;
};
};
...
@@ -1821,16 +1816,10 @@ mod tests_startup_helpers {
...
@@ -1821,16 +1816,10 @@ mod tests_startup_helpers {
"Router should only see 1 shared block (not the new block from event_2)"
"Router should only see 1 shared block (not the new block from event_2)"
);
);
// === STEP 4 & 5: Recovery - Query last received event IDs and fetch missed events ===
// === STEP 4 & 5: Recovery - Query worker's local indexer for missed events ===
// Step 4a: Router queries its last received event ID per worker
// In practice, the subscriber detects gaps and triggers recovery automatically.
let
last_ids
=
router_indexer
.get_last_received_event_ids
()
.await
.unwrap
();
// Here we simulate that by querying for events after event_id=1.
let
last_known_id
=
last_ids
.get
(
&
worker_1_id
)
.copied
()
.unwrap_or
(
0
);
let
last_known_id
=
1u64
;
// Router only received event_1
assert_eq!
(
last_known_id
,
1
,
"Router should have last_received_event_id = 1 for worker (only event_1 was forwarded)"
);
// Step 4b: Query worker's local indexer for events after last_known_id
let
response
=
local_indexer_1
let
response
=
local_indexer_1
.get_events_in_id_range
(
Some
(
last_known_id
+
1
),
None
)
.get_events_in_id_range
(
Some
(
last_known_id
+
1
),
None
)
.await
;
.await
;
...
@@ -1868,14 +1857,6 @@ mod tests_startup_helpers {
...
@@ -1868,14 +1857,6 @@ mod tests_startup_helpers {
"Router should now see both blocks after recovery"
"Router should now see both blocks after recovery"
);
);
// assert: Router's last_received_event_id is updated after recovery
let
last_ids_after
=
router_indexer
.get_last_received_event_ids
()
.await
.unwrap
();
assert_eq!
(
last_ids_after
.get
(
&
worker_1_id
),
Some
(
&
2
),
"Router should have last_received_event_id = 2 after recovery"
);
token
.cancel
();
token
.cancel
();
}
}
}
}
...
@@ -2043,8 +2024,6 @@ mod test_integration_publisher {
...
@@ -2043,8 +2024,6 @@ mod test_integration_publisher {
#[tokio::test]
#[tokio::test]
#[ignore]
// Mark as ignored as requested, because CI's integrations still don't have NATS
#[ignore]
// Mark as ignored as requested, because CI's integrations still don't have NATS
async
fn
test_kvstats_prometheus_gauge_updates
()
{
async
fn
test_kvstats_prometheus_gauge_updates
()
{
use
crate
::
kv_router
::
publisher
::
kvstats
;
// Test that publish() updates Prometheus gauges correctly using real Component
// Test that publish() updates Prometheus gauges correctly using real Component
let
publisher
=
WorkerMetricsPublisher
::
new
()
.unwrap
();
let
publisher
=
WorkerMetricsPublisher
::
new
()
.unwrap
();
...
...
lib/llm/src/kv_router/scoring.rs
deleted
100644 → 0
View file @
54636097
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Scoring functions for the KV router.
use
super
::
protocols
::
LoadMetrics
;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
collections
::
HashMap
;
/// A single endpoint together with its load metrics.
///
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Endpoint {
    /// Endpoint name.
    pub name: String,
    /// Endpoint subject; the text after its last `-` is parsed as the
    /// hexadecimal worker id (see [`Endpoint::worker_id`]).
    pub subject: String,
    /// Load metrics reported for this endpoint.
    pub data: LoadMetrics,
}

impl Endpoint {
    /// Returns the worker id encoded in `subject`.
    ///
    /// The id is the portion of `subject` after the last `-` (or the whole
    /// subject when it contains no `-`), interpreted as a base-16 `u64`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid worker id"` if that portion is not a valid
    /// hexadecimal `u64`.
    pub fn worker_id(&self) -> u64 {
        // `rsplit` yields the trailing segment first, so there is no need to
        // walk the whole string as `split(..).last()` would. It always yields
        // at least one item, so this `expect` is not reachable in practice.
        let id_hex = self.subject.rsplit('-').next().expect("invalid subject");
        u64::from_str_radix(id_hex, 16).expect("invalid worker id")
    }
}
/// A set of endpoints indexed by worker id, plus summary statistics of their
/// active KV block counts.
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints {
    /// Endpoints keyed by the id returned from `Endpoint::worker_id`.
    pub endpoints: HashMap<u64, Endpoint>,
    /// Arithmetic mean of `kv_active_blocks()` over all endpoints.
    pub load_avg: f64,
    /// Population standard deviation (divides by `n`, not `n - 1`) of
    /// `kv_active_blocks()` over all endpoints.
    pub load_std: f64,
}

impl ProcessedEndpoints {
    /// Builds the worker-id index and computes the load statistics.
    ///
    /// For an empty `endpoints` vector both `load_avg` and `load_std` are
    /// `0.0`, matching `ProcessedEndpoints::default()`, rather than the NaN
    /// that an unguarded `0.0 / 0.0` would produce.
    ///
    /// If several endpoints resolve to the same worker id, the last one in
    /// `endpoints` is kept in the map, while the statistics still include
    /// every element.
    ///
    /// # Panics
    ///
    /// Panics if any endpoint's `worker_id()` panics.
    pub fn new(endpoints: Vec<Endpoint>) -> Self {
        // compute some basic statistics
        let load_values: Vec<f64> = endpoints
            .iter()
            .map(|endpoint| endpoint.data.kv_active_blocks() as f64)
            .collect();

        let (load_avg, load_std) = if load_values.is_empty() {
            // Guard the division below: no samples means no load.
            (0.0, 0.0)
        } else {
            let sample_count = load_values.len() as f64;
            let load_avg = load_values.iter().copied().sum::<f64>() / sample_count;
            let variance = load_values
                .iter()
                .map(|&x| (x - load_avg).powi(2))
                .sum::<f64>()
                / sample_count;
            (load_avg, variance.sqrt())
        };

        let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();

        ProcessedEndpoints {
            endpoints,
            load_avg,
            load_std,
        }
    }

    /// Returns the ids of all known workers, in unspecified (hash map) order.
    pub fn worker_ids(&self) -> Vec<u64> {
        self.endpoints.keys().copied().collect()
    }

    /// Returns each worker's active KV block count, keyed by worker id.
    pub fn active_blocks(&self) -> HashMap<u64, usize> {
        self.endpoints
            .iter()
            .map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
            .collect()
    }
}
lib/llm/src/kv_router/subscriber.rs
View file @
0fba01c2
...
@@ -32,6 +32,10 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
...
@@ -32,6 +32,10 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const
CHECK_INTERVAL_BASE
:
Duration
=
Duration
::
from_secs
(
1
);
const
CHECK_INTERVAL_BASE
:
Duration
=
Duration
::
from_secs
(
1
);
const
CHECK_INTERVAL_JITTER_MS
:
i64
=
100
;
const
CHECK_INTERVAL_JITTER_MS
:
i64
=
100
;
// Worker query retry configuration
const
WORKER_QUERY_MAX_RETRIES
:
u32
=
8
;
const
WORKER_QUERY_INITIAL_BACKOFF_MS
:
u64
=
200
;
// ============================================================================
// ============================================================================
// Local KvIndexer-based Recovery
// Local KvIndexer-based Recovery
// ============================================================================
// ============================================================================
...
@@ -65,8 +69,7 @@ pub async fn recover_from_all_workers(
...
@@ -65,8 +69,7 @@ pub async fn recover_from_all_workers(
// Skip workers without local indexer
// Skip workers without local indexer
if
!
worker_query_client
.has_local_indexer
(
worker_id
)
{
if
!
worker_query_client
.has_local_indexer
(
worker_id
)
{
tracing
::
debug!
(
tracing
::
debug!
(
worker_id
,
"Skipping recovery - worker {worker_id} does not have local indexer enabled"
"Skipping recovery - worker does not have local indexer enabled"
);
);
continue
;
continue
;
}
}
...
@@ -101,10 +104,7 @@ pub async fn recover_from_all_workers(
...
@@ -101,10 +104,7 @@ pub async fn recover_from_all_workers(
// Log summary
// Log summary
if
total_recovered
>
0
||
failed_workers
>
0
{
if
total_recovered
>
0
||
failed_workers
>
0
{
tracing
::
info!
(
tracing
::
info!
(
total_recovered
,
"Startup recovery completed: {total_recovered} events recovered from {successful_workers} workers, {failed_workers} workers failed"
successful_workers
,
failed_workers
,
"Startup recovery completed"
);
);
}
}
...
@@ -133,35 +133,61 @@ pub async fn recover_from_worker(
...
@@ -133,35 +133,61 @@ pub async fn recover_from_worker(
)
->
Result
<
usize
>
{
)
->
Result
<
usize
>
{
if
worker_query_client
.has_local_indexer
(
worker_id
)
{
if
worker_query_client
.has_local_indexer
(
worker_id
)
{
tracing
::
debug!
(
tracing
::
debug!
(
worker_id
,
"Attempting recovery from worker {worker_id}, start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
start_event_id
=
?
start_event_id
,
end_event_id
=
?
end_event_id
,
"Attempting recovery from worker"
);
);
}
else
{
}
else
{
tracing
::
warn!
(
tracing
::
warn!
(
"Worker {worker_id} does not have local indexer enabled, skipping recovery"
);
worker_id
,
"Worker does not have local indexer enabled, skipping recovery"
);
return
Ok
(
0
);
return
Ok
(
0
);
}
}
// Query worker for events in range
// Query worker for events in range, with retry logic for transient failures
let
response
=
worker_query_client
// (e.g., worker's query service not yet re-subscribed after NATS restart)
let
mut
response
=
None
;
let
mut
last_error
=
None
;
for
attempt
in
0
..
WORKER_QUERY_MAX_RETRIES
{
match
worker_query_client
.query_worker
(
worker_id
,
start_event_id
,
end_event_id
)
.query_worker
(
worker_id
,
start_event_id
,
end_event_id
)
.await
?
;
.await
{
Ok
(
resp
)
=>
{
if
attempt
>
0
{
tracing
::
info!
(
"Worker {worker_id} query succeeded after retry {attempt}"
);
}
response
=
Some
(
resp
);
break
;
}
Err
(
e
)
=>
{
last_error
=
Some
(
e
);
if
attempt
<
WORKER_QUERY_MAX_RETRIES
-
1
{
let
backoff_ms
=
WORKER_QUERY_INITIAL_BACKOFF_MS
*
2_u64
.pow
(
attempt
);
tracing
::
warn!
(
"Worker {worker_id} query failed on attempt {attempt}, retrying after {backoff_ms}ms"
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
backoff_ms
))
.await
;
}
}
}
}
let
response
=
match
response
{
Some
(
r
)
=>
r
,
None
=>
return
Err
(
last_error
.unwrap_or_else
(||
anyhow
::
anyhow!
(
"No response"
))),
};
// Handle response variants
// Handle response variants
let
events
=
match
response
{
let
events
=
match
response
{
WorkerKvQueryResponse
::
Events
(
events
)
=>
{
WorkerKvQueryResponse
::
Events
(
events
)
=>
{
tracing
::
debug!
(
worker_id
,
count
=
events
.len
(),
"Got buffered events"
);
tracing
::
debug!
(
"Got {count} buffered events from worker {worker_id}"
,
count
=
events
.len
()
);
events
events
}
}
WorkerKvQueryResponse
::
TreeDump
(
events
)
=>
{
WorkerKvQueryResponse
::
TreeDump
(
events
)
=>
{
tracing
::
info!
(
tracing
::
info!
(
worker_id
,
"Got tree dump from worker {worker_id} (range too old or unspecified), count: {count}"
,
count
=
events
.len
(),
count
=
events
.len
()
"Got tree dump (range too old or unspecified)"
);
);
events
events
}
}
...
@@ -171,11 +197,7 @@ pub async fn recover_from_worker(
...
@@ -171,11 +197,7 @@ pub async fn recover_from_worker(
newest_available
,
newest_available
,
}
=>
{
}
=>
{
tracing
::
warn!
(
tracing
::
warn!
(
worker_id
,
"Worker {worker_id} requested range is newer than available data: requested_start: {requested_start:?}, requested_end: {requested_end:?}, newest_available: {newest_available}"
?
requested_start
,
?
requested_end
,
newest_available
,
"Requested range is newer than available data"
);
);
return
Ok
(
0
);
return
Ok
(
0
);
}
}
...
@@ -188,24 +210,21 @@ pub async fn recover_from_worker(
...
@@ -188,24 +210,21 @@ pub async fn recover_from_worker(
if
events_count
==
0
{
if
events_count
==
0
{
tracing
::
debug!
(
tracing
::
debug!
(
worker_id
,
"No events to recover from worker {worker_id}, start_event_id: {start_event_id:?}"
start_event_id
=
?
start_event_id
,
"No events to recover from worker"
);
);
return
Ok
(
0
);
return
Ok
(
0
);
}
}
tracing
::
info!
(
tracing
::
info!
(
worker_id
,
"Recovered {events_count} events from worker {worker_id}, start_event_id: {start_event_id:?}"
start_event_id
=
?
start_event_id
,
events_count
,
"Recovered {events_count} events from worker"
);
);
// Apply recovered events to the indexer
// Apply recovered events to the indexer
for
event
in
events
{
for
event
in
events
{
if
let
Err
(
e
)
=
event_tx
.send
(
event
)
.await
{
if
let
Err
(
e
)
=
event_tx
.send
(
event
)
.await
{
tracing
::
error!
(
worker_id
,
error
=
%
e
,
"Failed to send recovered event to indexer"
);
tracing
::
error!
(
"Failed to send recovered event to indexer for worker {worker_id}: {e}"
);
anyhow
::
bail!
(
"Failed to send recovered event: {e}"
);
anyhow
::
bail!
(
"Failed to send recovered event: {e}"
);
}
}
}
}
...
@@ -528,8 +547,7 @@ pub async fn start_kv_router_background(
...
@@ -528,8 +547,7 @@ pub async fn start_kv_router_background(
};
};
tracing
::
warn!
(
tracing
::
warn!
(
worker_id
=
worker_id
,
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
"DISCOVERY: Generate endpoint instance removed, removing worker"
);
);
if
let
Err
(
e
)
=
remove_worker_tx
.send
(
worker_id
)
.await
{
if
let
Err
(
e
)
=
remove_worker_tx
.send
(
worker_id
)
.await
{
...
@@ -611,8 +629,7 @@ pub async fn start_kv_router_background(
...
@@ -611,8 +629,7 @@ pub async fn start_kv_router_background(
let
consumer_to_delete
=
router_instance_id
.to_string
();
let
consumer_to_delete
=
router_instance_id
.to_string
();
tracing
::
info!
(
tracing
::
info!
(
router_instance_id
=
router_instance_id
,
"DISCOVERY: Router instance {router_instance_id} removed, attempting to delete orphaned consumer: {consumer_to_delete}"
"DISCOVERY: Router instance removed, attempting to delete orphaned consumer: {consumer_to_delete}"
);
);
// Delete the consumer (allow race condition if multiple routers try to delete)
// Delete the consumer (allow race condition if multiple routers try to delete)
...
@@ -653,11 +670,11 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -653,11 +670,11 @@ pub async fn start_kv_router_background_nats_core(
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
// Subscribe to KV events using NATS Core
// Subscribe to KV events using NATS Core
let
mut
subscriber
=
component
.subscribe
(
KV_EVENT_SUBJECT
)
.await
?
;
let
mut
subscriber
=
component
.subscribe
(
KV_EVENT_SUBJECT
)
.await
?
;
let
kv_event_subject
=
format!
(
"{}.{}"
,
component
.subject
(),
KV_EVENT_SUBJECT
);
tracing
::
info!
(
tracing
::
info!
(
"KV Router using NATS Core subscription on subject: {}.{} (local_indexer mode)"
,
subject
=
%
kv_event_subject
,
component
.subject
(),
"KV Router using NATS Core subscription (local_indexer mode)"
KV_EVENT_SUBJECT
);
);
// Get the generate endpoint and watch for instance events (add/remove)
// Get the generate endpoint and watch for instance events (add/remove)
...
@@ -696,8 +713,7 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -696,8 +713,7 @@ pub async fn start_kv_router_background_nats_core(
let
worker_id
=
_
instance
.instance_id
();
let
worker_id
=
_
instance
.instance_id
();
tracing
::
info!
(
tracing
::
info!
(
worker_id
=
worker_id
,
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
"DISCOVERY: Worker added, dumping local indexer into router"
);
);
// Query worker's local indexer and dump all events
// Query worker's local indexer and dump all events
...
@@ -712,24 +728,19 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -712,24 +728,19 @@ pub async fn start_kv_router_background_nats_core(
{
{
Ok
(
count
)
=>
{
Ok
(
count
)
=>
{
tracing
::
info!
(
tracing
::
info!
(
worker_id
=
worker_id
,
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
events_recovered
=
count
,
"Successfully dumped worker's local indexer"
);
);
}
}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
tracing
::
warn!
(
tracing
::
warn!
(
worker_id
=
worker_id
,
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
error
=
%
e
,
"Failed to dump worker's local indexer (may not have local indexer enabled)"
);
);
}
}
}
}
}
}
DiscoveryEvent
::
Removed
(
worker_id
)
=>
{
DiscoveryEvent
::
Removed
(
worker_id
)
=>
{
tracing
::
warn!
(
tracing
::
warn!
(
worker_id
=
worker_id
,
"DISCOVERY: Worker {worker_id} removed, removing from router indexer"
"DISCOVERY: Worker removed, removing from router indexer"
);
);
if
let
Err
(
e
)
=
remove_worker_tx
.send
(
worker_id
)
.await
{
if
let
Err
(
e
)
=
remove_worker_tx
.send
(
worker_id
)
.await
{
...
@@ -760,12 +771,9 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -760,12 +771,9 @@ pub async fn start_kv_router_background_nats_core(
// Gap detected - recover missing events before processing current
// Gap detected - recover missing events before processing current
let
gap_start
=
last_id
+
1
;
let
gap_start
=
last_id
+
1
;
let
gap_end
=
event_id
-
1
;
let
gap_end
=
event_id
-
1
;
let
gap_size
=
gap_end
-
gap_start
+
1
;
tracing
::
warn!
(
tracing
::
warn!
(
worker_id
,
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
gap_start
,
gap_end
,
gap_size
=
gap_end
-
gap_start
+
1
,
"Event ID gap detected, recovering events [{gap_start}, {gap_end}]"
);
);
// Note: While recovering, new events may queue in the NATS subscriber's
// Note: While recovering, new events may queue in the NATS subscriber's
...
@@ -779,11 +787,7 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -779,11 +787,7 @@ pub async fn start_kv_router_background_nats_core(
&
kv_events_tx
,
&
kv_events_tx
,
)
.await
{
)
.await
{
tracing
::
error!
(
tracing
::
error!
(
worker_id
,
"Failed to recover gap events for worker {worker_id} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
gap_start
,
gap_end
,
error
=
%
e
,
"Failed to recover gap events; proceeding with current event anyway"
);
);
// Note: If recovery fails, we still apply the current event.
// Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event.
// The tree will have a gap, but it's better than dropping the event.
...
...
tests/conftest.py
View file @
0fba01c2
...
@@ -305,6 +305,8 @@ class NatsServer(ManagedProcess):
...
@@ -305,6 +305,8 @@ class NatsServer(ManagedProcess):
self
.
port
=
port
self
.
port
=
port
self
.
use_random_port
=
use_random_port
# Track if we allocated the port
self
.
use_random_port
=
use_random_port
# Track if we allocated the port
self
.
_request
=
request
# Store for restart
self
.
_timeout
=
timeout
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
command
=
[
command
=
[
"nats-server"
,
"nats-server"
,
...
@@ -336,6 +338,39 @@ class NatsServer(ManagedProcess):
...
@@ -336,6 +338,39 @@ class NatsServer(ManagedProcess):
return
super
().
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
super
().
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
def
stop
(
self
):
"""Stop the NATS server for restart. Does not release port or clean up fully."""
_logger
.
info
(
f
"Stopping NATS server on port
{
self
.
port
}
"
)
self
.
_terminate_process_group
()
if
self
.
proc
:
try
:
self
.
proc
.
wait
(
timeout
=
10
)
except
Exception
as
e
:
_logger
.
warning
(
f
"Error waiting for NATS process to stop:
{
e
}
"
)
self
.
proc
=
None
def
start
(
self
):
"""Restart a stopped NATS server with fresh state."""
_logger
.
info
(
f
"Starting NATS server on port
{
self
.
port
}
with fresh state"
)
# Clean up old data directory and create fresh one
if
self
.
data_dir
:
shutil
.
rmtree
(
self
.
data_dir
,
ignore_errors
=
True
)
self
.
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
# Rebuild command with new data_dir
self
.
command
=
[
"nats-server"
,
"-js"
,
"--trace"
,
"--store_dir"
,
self
.
data_dir
,
"-p"
,
str
(
self
.
port
),
]
self
.
_start_process
()
self
.
_check_ports
(
self
.
_timeout
)
class
SharedManagedProcess
:
class
SharedManagedProcess
:
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
...
...
tests/router/common.py
View file @
0fba01c2
...
@@ -8,7 +8,7 @@ import os
...
@@ -8,7 +8,7 @@ import os
import
random
import
random
import
string
import
string
import
time
import
time
from
typing
import
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
aiohttp
import
aiohttp
import
nats
import
nats
...
@@ -16,6 +16,9 @@ import nats
...
@@ -16,6 +16,9 @@ import nats
from
dynamo._core
import
DistributedRuntime
,
KvPushRouter
,
KvRouterConfig
from
dynamo._core
import
DistributedRuntime
,
KvPushRouter
,
KvRouterConfig
from
tests.utils.managed_process
import
ManagedProcess
from
tests.utils.managed_process
import
ManagedProcess
if
TYPE_CHECKING
:
from
tests.conftest
import
NatsServer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
NUM_REQUESTS
=
100
NUM_REQUESTS
=
100
...
@@ -220,7 +223,7 @@ async def wait_for_frontend_ready(
...
@@ -220,7 +223,7 @@ async def wait_for_frontend_ready(
logger
.
debug
(
f
"Error checking models endpoint:
{
e
}
"
)
logger
.
debug
(
f
"Error checking models endpoint:
{
e
}
"
)
# Wait before next poll
# Wait before next poll
await
asyncio
.
sleep
(
2
)
await
asyncio
.
sleep
(
1
)
# Phase 2: Wait for chat completions pipeline to be ready
# Phase 2: Wait for chat completions pipeline to be ready
logger
.
info
(
"Waiting for chat completions pipeline to be built..."
)
logger
.
info
(
"Waiting for chat completions pipeline to be built..."
)
...
@@ -253,7 +256,7 @@ async def wait_for_frontend_ready(
...
@@ -253,7 +256,7 @@ async def wait_for_frontend_ready(
logger
.
debug
(
f
"Error testing chat completions:
{
e
}
"
)
logger
.
debug
(
f
"Error testing chat completions:
{
e
}
"
)
# Wait before next poll
# Wait before next poll
await
asyncio
.
sleep
(
2
)
await
asyncio
.
sleep
(
1
)
async
def
wait_for_workers_ready
(
async
def
wait_for_workers_ready
(
...
@@ -1321,6 +1324,9 @@ def _test_router_indexers_sync(
...
@@ -1321,6 +1324,9 @@ def _test_router_indexers_sync(
model_name
:
str
,
model_name
:
str
,
num_workers
:
int
,
num_workers
:
int
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
request_plane
:
str
=
"nats"
,
test_nats_interruption
:
bool
=
False
,
nats_server
:
Optional
[
"NatsServer"
]
=
None
,
):
):
"""Test that two KV routers have synchronized indexer states after processing requests.
"""Test that two KV routers have synchronized indexer states after processing requests.
...
@@ -1333,16 +1339,30 @@ def _test_router_indexers_sync(
...
@@ -1333,16 +1339,30 @@ def _test_router_indexers_sync(
This validates that the snapshot mechanism works and routers can sync state from NATS.
This validates that the snapshot mechanism works and routers can sync state from NATS.
When test_nats_interruption=True (requires nats_server and request_plane="tcp"):
- After first router sends 25 requests, NATS is stopped
- 10 more requests sent while NATS is down (stored locally by local indexer)
- NATS restarted (fresh state), recovery mechanism re-syncs
- Second router starts and sends 25 requests
- NATS stopped again, 10 more requests sent
- NATS restarted, 5 more requests sent
- Verify both routers converge to same state
Args:
Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
block_size: Block size for KV cache
block_size: Block size for KV cache
model_name: Model name to use for requests
model_name: Model name to use for requests
num_workers: Expected number of workers
num_workers: Expected number of workers
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats".
test_nats_interruption: If True, test NATS interruption recovery. Defaults to False.
nats_server: NatsServer instance for stop/start (required if test_nats_interruption=True).
Raises:
Raises:
AssertionError: If router states don't synchronize correctly or snapshot is missing
AssertionError: If router states don't synchronize correctly or snapshot is missing
"""
"""
if
test_nats_interruption
and
nats_server
is
None
:
raise
ValueError
(
"nats_server is required when test_nats_interruption=True"
)
# Use async to manage the test flow
# Use async to manage the test flow
async
def
test_sync
():
async
def
test_sync
():
...
@@ -1386,7 +1406,7 @@ def _test_router_indexers_sync(
...
@@ -1386,7 +1406,7 @@ def _test_router_indexers_sync(
# Create first runtime and endpoint for router 1
# Create first runtime and endpoint for router 1
logger
.
info
(
"Creating first KV router with its own runtime"
)
logger
.
info
(
"Creating first KV router with its own runtime"
)
runtime1
=
get_runtime
(
store_backend
)
runtime1
=
get_runtime
(
store_backend
,
request_plane
)
namespace1
=
runtime1
.
namespace
(
engine_workers
.
namespace
)
namespace1
=
runtime1
.
namespace
(
engine_workers
.
namespace
)
component1
=
namespace1
.
component
(
engine_workers
.
component_name
)
component1
=
namespace1
.
component
(
engine_workers
.
component_name
)
endpoint1
=
component1
.
endpoint
(
"generate"
)
endpoint1
=
component1
.
endpoint
(
"generate"
)
...
@@ -1413,13 +1433,35 @@ def _test_router_indexers_sync(
...
@@ -1413,13 +1433,35 @@ def _test_router_indexers_sync(
successful1
==
25
successful1
==
25
),
f
"Expected 25 successful requests to router 1, got
{
successful1
}
"
),
f
"Expected 25 successful requests to router 1, got
{
successful1
}
"
# NATS interruption test: stop NATS, send requests, restart
if
test_nats_interruption
:
await
asyncio
.
sleep
(
1
)
assert
nats_server
is
not
None
# Validated at function entry
logger
.
info
(
"=== NATS INTERRUPTION TEST: Phase 1 ==="
)
logger
.
info
(
"Stopping NATS server"
)
nats_server
.
stop
()
logger
.
info
(
"Sending 10 requests while NATS is down (via TCP)"
)
successful_offline1
=
await
send_requests_to_router
(
kv_push_router1
,
10
,
"Router 1 (NATS down)"
,
endpoint1
)
assert
(
successful_offline1
==
10
),
f
"Expected 10 successful requests while NATS down, got
{
successful_offline1
}
"
logger
.
info
(
"Restarting NATS server (fresh state)"
)
nats_server
.
start
()
await
asyncio
.
sleep
(
5
)
# Wait for a second before creating the second router
# Wait for a second before creating the second router
logger
.
info
(
"Waiting for 1 second before creating second router"
)
logger
.
info
(
"Waiting for 1 second before creating second router"
)
await
asyncio
.
sleep
(
1
)
await
asyncio
.
sleep
(
1
)
# Create second runtime and endpoint for router 2
# Create second runtime and endpoint for router 2
logger
.
info
(
"Creating second KV router with its own runtime"
)
logger
.
info
(
"Creating second KV router with its own runtime"
)
runtime2
=
get_runtime
(
store_backend
)
runtime2
=
get_runtime
(
store_backend
,
request_plane
)
namespace2
=
runtime2
.
namespace
(
engine_workers
.
namespace
)
namespace2
=
runtime2
.
namespace
(
engine_workers
.
namespace
)
component2
=
namespace2
.
component
(
engine_workers
.
component_name
)
component2
=
namespace2
.
component
(
engine_workers
.
component_name
)
endpoint2
=
component2
.
endpoint
(
"generate"
)
endpoint2
=
component2
.
endpoint
(
"generate"
)
...
@@ -1439,12 +1481,44 @@ def _test_router_indexers_sync(
...
@@ -1439,12 +1481,44 @@ def _test_router_indexers_sync(
successful2
==
25
successful2
==
25
),
f
"Expected 25 successful requests to router 2, got
{
successful2
}
"
),
f
"Expected 25 successful requests to router 2, got
{
successful2
}
"
# NATS interruption test: stop NATS again, send requests, restart, send more
if
test_nats_interruption
:
await
asyncio
.
sleep
(
1
)
assert
nats_server
is
not
None
# Validated at function entry
logger
.
info
(
"=== NATS INTERRUPTION TEST: Phase 2 ==="
)
logger
.
info
(
"Stopping NATS server"
)
nats_server
.
stop
()
logger
.
info
(
"Sending 10 requests while NATS is down (via TCP)"
)
successful_offline2
=
await
send_requests_to_router
(
kv_push_router2
,
10
,
"Router 2 (NATS down)"
,
endpoint2
)
assert
(
successful_offline2
==
10
),
f
"Expected 10 successful requests while NATS down, got
{
successful_offline2
}
"
logger
.
info
(
"Restarting NATS server (fresh state)"
)
nats_server
.
start
()
await
asyncio
.
sleep
(
5
)
logger
.
info
(
"Sending 5 more requests after NATS recovery"
)
successful_recovery
=
await
send_requests_to_router
(
kv_push_router1
,
5
,
"Router 1 (post-recovery)"
,
endpoint1
)
assert
(
successful_recovery
==
5
),
f
"Expected 5 successful requests post-recovery, got
{
successful_recovery
}
"
# Wait for all requests to complete (they should already be complete from gather)
# Wait for all requests to complete (they should already be complete from gather)
# Wait another 1 second for internal synchronization
# Wait another 1 second for internal synchronization
logger
.
info
(
"Waiting for final synchronization"
)
logger
.
info
(
"Waiting for final synchronization"
)
await
asyncio
.
sleep
(
1
)
await
asyncio
.
sleep
(
1
)
# Verify NATS object store bucket was created with snapshot
# Verify NATS object store bucket was created with snapshot
# Skip this verification for NATS interruption test since NATS restarts fresh
# (local indexer recovery doesn't rely on NATS persistence)
if
not
test_nats_interruption
:
# Mirror the Rust bucket naming logic from subscriber.rs:
# Mirror the Rust bucket naming logic from subscriber.rs:
# component.subject() -> "namespace.{ns}.component.{comp}"
# component.subject() -> "namespace.{ns}.component.{comp}"
# then slugify (convert dots to dashes, lowercase, etc) and append "-radix-bucket"
# then slugify (convert dots to dashes, lowercase, etc) and append "-radix-bucket"
...
@@ -1485,6 +1559,10 @@ def _test_router_indexers_sync(
...
@@ -1485,6 +1559,10 @@ def _test_router_indexers_sync(
f
"Expected snapshot to be created in bucket '
{
expected_bucket
}
' with file '
{
expected_file
}
'. "
f
"Expected snapshot to be created in bucket '
{
expected_bucket
}
' with file '
{
expected_file
}
'. "
f
"Router sent 25 requests with snapshot_threshold=20, so snapshot should have been triggered."
f
"Router sent 25 requests with snapshot_threshold=20, so snapshot should have been triggered."
)
)
else
:
logger
.
info
(
"Skipping NATS object store verification (NATS was restarted fresh for interruption test)"
)
# Dump states from both routers
# Dump states from both routers
logger
.
info
(
"Dumping states from both routers"
)
logger
.
info
(
"Dumping states from both routers"
)
...
@@ -1562,6 +1640,8 @@ def _test_router_indexers_sync(
...
@@ -1562,6 +1640,8 @@ def _test_router_indexers_sync(
logger
.
info
(
"Successfully verified that both router states are equal"
)
logger
.
info
(
"Successfully verified that both router states are equal"
)
# Verify NATS consumers are created (while routers are still alive)
# Verify NATS consumers are created (while routers are still alive)
# Skip this for NATS interruption test since it uses local indexer (NATS Core, not JetStream)
if
not
test_nats_interruption
:
logger
.
info
(
"Verifying NATS consumers exist for both routers"
)
logger
.
info
(
"Verifying NATS consumers exist for both routers"
)
component_subject
=
f
"namespace.
{
engine_workers
.
namespace
}
.component.
{
engine_workers
.
component_name
}
"
component_subject
=
f
"namespace.
{
engine_workers
.
namespace
}
.component.
{
engine_workers
.
component_name
}
"
slugified
=
component_subject
.
lower
().
replace
(
"."
,
"-"
).
replace
(
"_"
,
"-"
)
slugified
=
component_subject
.
lower
().
replace
(
"."
,
"-"
).
replace
(
"_"
,
"-"
)
...
@@ -1581,6 +1661,10 @@ def _test_router_indexers_sync(
...
@@ -1581,6 +1661,10 @@ def _test_router_indexers_sync(
logger
.
info
(
"✓ Verified 2 durable consumers exist (one per router)"
)
logger
.
info
(
"✓ Verified 2 durable consumers exist (one per router)"
)
finally
:
finally
:
await
nc
.
close
()
await
nc
.
close
()
else
:
logger
.
info
(
"Skipping NATS consumers verification (local indexer uses NATS Core, not JetStream)"
)
# Run the async test
# Run the async test
asyncio
.
run
(
test_sync
())
asyncio
.
run
(
test_sync
())
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
0fba01c2
...
@@ -2,10 +2,12 @@
...
@@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
logging
import
logging
import
os
import
os
from
contextlib
import
nullcontext
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
pytest
import
pytest
from
tests.conftest
import
EtcdServer
,
NatsServer
from
tests.router.common
import
(
# utilities
from
tests.router.common
import
(
# utilities
_test_busy_threshold_endpoint
,
_test_busy_threshold_endpoint
,
_test_python_router_bindings
,
_test_python_router_bindings
,
...
@@ -476,25 +478,50 @@ def test_kv_push_router_bindings(
...
@@ -476,25 +478,50 @@ def test_kv_push_router_bindings(
mockers
.
__exit__
(
None
,
None
,
None
)
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
# NO @pytest.mark.parallel - nats_core variant stops/restarts NATS
@
pytest
.
mark
.
parametrize
(
"store_backend,use_nats_core,request_plane"
,
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
# ("etcd", True, "tcp"), # ignored, needs unconditional nats_client
(
"file"
,
False
,
"nats"
),
# File backend
],
ids
=
[
"jetstream"
,
"file"
],
# "nats_core" commented out to match commented test case
)
def
test_indexers_sync
(
def
test_indexers_sync
(
request
,
request
,
runtime_services_session
,
predownload_tokenizers
,
predownload_tokenizers
,
file_storage_backend
,
file_storage_backend
,
store_backend
,
store_backend
,
use_nats_core
,
request_plane
,
):
):
"""
"""
Test that two KV routers have synchronized indexer states after processing requests.
Test that two KV routers have synchronized indexer states after processing requests.
This test verifies that both routers converge to the same internal state.
This test verifies that both routers converge to the same internal state.
Tests with both etcd and file storage backends.
"""
# runtime_services starts etcd and nats
Tests with three configurations:
logger
.
info
(
f
"Starting indexers sync test with
{
store_backend
}
storage backend"
)
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- nats_core: etcd backend, local indexer with NATS Core, TCP request plane
(includes NATS interruption/recovery testing)
- file: file backend, JetStream for KV events, NATS request plane
"""
logger
.
info
(
f
"Starting indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
)
# Start NATS manually (needed for all variants - KV event sync)
with
NatsServer
(
request
)
as
nats_server
:
# Start etcd if needed
etcd_ctx
=
EtcdServer
(
request
)
if
store_backend
==
"etcd"
else
nullcontext
()
with
etcd_ctx
:
# Create mocker args dictionary
# Create mocker args dictionary
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
try
:
# Start mocker instances
# Start mocker instances
...
@@ -504,6 +531,7 @@ def test_indexers_sync(
...
@@ -504,6 +531,7 @@ def test_indexers_sync(
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
,
num_mockers
=
NUM_MOCKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
)
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
mockers
.
__enter__
()
mockers
.
__enter__
()
...
@@ -516,6 +544,9 @@ def test_indexers_sync(
...
@@ -516,6 +544,9 @@ def test_indexers_sync(
model_name
=
MODEL_NAME
,
model_name
=
MODEL_NAME
,
num_workers
=
NUM_MOCKERS
,
num_workers
=
NUM_MOCKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_server
if
use_nats_core
else
None
,
)
)
logger
.
info
(
"Indexers sync test completed successfully"
)
logger
.
info
(
"Indexers sync test completed successfully"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment