Unverified Commit 0fba01c2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: make gap detection work e2e (#4993)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 54636097
......@@ -31,7 +31,6 @@ pub mod protocols;
pub mod publisher;
pub mod recorder;
pub mod scheduler;
pub mod scoring;
pub mod sequence;
pub mod subscriber;
pub mod worker_query;
......
......@@ -814,13 +814,6 @@ pub struct GetWorkersRequest {
pub resp: oneshot::Sender<Vec<WorkerId>>,
}
/// A request to get the last received event ID per worker.
/// Used for fault tolerance recovery to determine which events to request from workers.
pub struct GetLastReceivedEventIdsRequest {
/// Channel to send the last received event IDs per worker
pub resp: oneshot::Sender<HashMap<WorkerId, u64>>,
}
#[async_trait]
pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es.
......@@ -926,8 +919,6 @@ pub struct KvIndexer {
dump_tx: mpsc::Sender<DumpRequest>,
/// A sender for routing decision requests.
routing_tx: mpsc::Sender<RoutingDecisionRequest>,
/// A sender for getting last received event IDs (for fault tolerance recovery).
last_event_ids_tx: mpsc::Sender<GetLastReceivedEventIdsRequest>,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
/// Reference counter for Clone-aware Drop.
......@@ -962,8 +953,6 @@ impl KvIndexer {
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048);
let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
let (last_event_ids_tx, mut last_event_ids_rx) =
mpsc::channel::<GetLastReceivedEventIdsRequest>(16);
let cancel_clone = token.clone();
......@@ -989,10 +978,6 @@ impl KvIndexer {
});
let mut event_id_counter = 0u64;
// Track last received event ID per worker (for fault tolerance recovery)
// Only used when enable_event_tracking is true
let mut last_received_event_id: HashMap<WorkerId, u64> = HashMap::new();
loop {
// Create a future that sleeps until the next expiration time
let expiry_fut = if let Some(ref pm) = prune_manager
......@@ -1019,10 +1004,6 @@ impl KvIndexer {
let _ = get_workers_req.resp.send(workers);
}
Some(req) = last_event_ids_rx.recv() => {
let _ = req.resp.send(last_received_event_id.clone());
}
Some(_) = prune_rx.recv() => {
// Tree size-based pruning triggered
let Some(ref mut pm) = prune_manager else { continue };
......@@ -1045,33 +1026,6 @@ impl KvIndexer {
}
Some(event) = event_rx.recv() => {
// Track last received event ID per worker
// Check for gaps before updating the last received ID
// TODO should this trigger a recovery event?
let last_id = *last_received_event_id.get(&event.worker_id).unwrap_or(&0);
let incoming_id = event.event.event_id;
// Detect gap: if incoming ID is more than 1 greater than last received
if incoming_id > last_id + 1 && last_id > 0 {
let gap_start = last_id + 1;
let gap_end = incoming_id - 1;
tracing::warn!(
worker_id = event.worker_id,
gap_start,
gap_end,
gap_size = gap_end - gap_start + 1,
"Event ID gap detected! Missed events [{}, {}]. \
If this is a global KvIndexer, within a KvRouter context,
consider calling KvRouter::query_worker_local_kv() to potentially recover worker-stored events.",
gap_start,
gap_end,
);
}
// Update last received event ID (use max to handle out-of-order events)
let entry = last_received_event_id.entry(event.worker_id).or_insert(0);
*entry = (*entry).max(event.event.event_id);
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event.clone());
let result_is_ok = result.is_ok();
......@@ -1200,7 +1154,6 @@ impl KvIndexer {
get_workers_tx,
dump_tx,
routing_tx,
last_event_ids_tx,
kv_block_size,
_ref_count: Arc::new(()),
}
......@@ -1253,48 +1206,6 @@ impl KvIndexer {
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.get_workers_tx.clone()
}
/// Get a sender for last received event IDs requests.
///
/// ### Returns
///
/// A `mpsc::Sender` for `GetLastReceivedEventIdsRequest`s.
pub fn last_event_ids_sender(&self) -> mpsc::Sender<GetLastReceivedEventIdsRequest> {
self.last_event_ids_tx.clone()
}
/// Get the last received event ID for each worker.
///
/// This method is used for **fault tolerance recovery** when the router needs to
/// catch up on missed events after a disconnect. By tracking the last event ID
/// received from each worker, the router can query workers for events starting
/// from `last_id + 1` to recover missed state.
///
/// **Note**: This method is intdned for the global `KvIndexer` used by routers,
/// not on `LocalKvIndexer` (worker-side) or `KvIndexerSharded`.
///
/// ### Returns
///
/// A `HashMap` mapping worker IDs to their last received event ID.
///
pub async fn get_last_received_event_ids(
&self,
) -> Result<HashMap<WorkerId, u64>, KvRouterError> {
let (resp_tx, resp_rx) = oneshot::channel();
let req = GetLastReceivedEventIdsRequest { resp: resp_tx };
if let Err(e) = self.last_event_ids_tx.send(req).await {
tracing::error!(
"Failed to send last event IDs request: {:?}; the indexer maybe offline",
e
);
return Err(KvRouterError::IndexerOffline);
}
resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
}
#[async_trait]
......@@ -1571,7 +1482,7 @@ impl LocalKvIndexer {
"Non-consecutive KV event id; buffer may have gaps"
);
}
tracing::info!(
tracing::debug!(
"Recorded event {:?} in buffer, now size is {}",
event,
buffer.len()
......@@ -1640,384 +1551,138 @@ impl LocalKvIndexer {
}
}
#[cfg(test)]
mod local_kv_indexer_tests {
use super::*;
fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
32,
);
{
let mut buffer = indexer.event_buffer.lock().unwrap();
for &id in ids {
buffer.push_back(RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Cleared,
dp_rank: 0,
},
));
}
}
indexer
// Implement KvIndexerInterface by delegating to the underlying indexer
#[async_trait]
impl KvIndexerInterface for LocalKvIndexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches(sequence).await
}
#[tokio::test]
async fn returns_slice_within_range() {
let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]);
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
_ => panic!("Unexpected response type"),
}
};
async fn find_matches_for_request(
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches_for_request(tokens).await
}
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
async fn apply_event(&mut self, event: RouterEvent) {
// Use the buffering version
let _ = self.apply_event_with_buffer(event).await;
}
// Test get_events_in_id_range (buffer queries)
// Range is [start, end] inclusive
let result = indexer.get_events_in_id_range(Some(2), Some(4)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4]
async fn remove_worker(&mut self, worker: WorkerId) {
let _ = self.indexer.remove_worker_sender().send(worker).await;
}
let result = indexer.get_events_in_id_range(Some(2), Some(6)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max
fn shutdown(&mut self) {
// Note: Since indexer is Arc<KvIndexer>, we can't call mutable methods directly.
// The indexer will be shut down when the CancellationToken is cancelled
// or when the last Arc reference is dropped.
}
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![3]); // single element when start == end
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
// Invalid range: end < start
let result = indexer.get_events_in_id_range(Some(5), Some(2)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision_for_request(tokens, worker)
.await
}
}
#[tokio::test]
async fn test_get_events_in_id_range_all_cases() {
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
}
// Create indexer with small buffer (5 events max)
// This way older events will only be in the tree, not the buffer
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4, // block_size
Arc::new(KvIndexerMetrics::new_unregistered()),
5, // max_buffer_size - only keeps 5 most recent events
);
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
///
/// ## Sharding Strategy
/// - Each worker is **permanently assigned** to a single shard on first event
/// - All KV blocks from a worker exist only in that worker's assigned shard
/// - New workers are assigned to the shard with the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own thread with a single-threaded runtime
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>,
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0, // worker_id
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
}],
}),
dp_rank: 0,
},
)
};
event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Vec<JoinHandle<()>>,
}
// Add 10 events (IDs 5-14)
// Buffer will only keep the last 5: events 10-14
// Tree will have all blocks
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
impl KvIndexerSharded {
/// Create a new `KvIndexerSharded`.
///
/// ### Arguments
///
/// * `token` - A `CancellationToken` for managing shutdown.
/// * `shards` - A list of kvindexer shards.
/// * `expiration_duration` - The amount of time that block usage should be buffered.
/// * `ttl` - The time-to-live for blocks before they expire.
/// * `prune_config` - Configuration for tree-size based pruning.
///
/// ### Returns
///
/// A new `KvIndexer`.
pub fn new_with_frequency(
token: CancellationToken,
num_shards: usize,
expiration_duration: Option<Duration>,
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
// Wait for events to be processed by the tree
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
let mut tasks = Vec::new();
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
// Helper to extract event IDs from result
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Verify buffer state: should have events 10-14 (last 5)
let buffer_events = indexer.get_all_events_in_buffer();
assert_eq!(
get_ids(buffer_events),
vec![10, 11, 12, 13, 14],
"Buffer should have events 10-14"
);
// ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
// Range is [start, end] inclusive
// Test: start_id within buffer, no end
let result = indexer.get_events_in_id_range(Some(11), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![11, 12, 13, 14],
"start_id=11 (in buffer) should return [11, 14]"
);
// Test: start_id at buffer boundary
let result = indexer.get_events_in_id_range(Some(10), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"start_id=10 (buffer start) should return [10, 14]"
);
// Test: both start and end within buffer (inclusive)
let result = indexer.get_events_in_id_range(Some(11), Some(13)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![11, 12, 13],
"range [11, 13] inclusive should return 3 events"
);
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"range [10, 14] should return all buffer events"
);
// ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
// Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
// that we get events back (the IDs won't match original IDs)
// Test: (None, None) dumps entire tree
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"(None, None) should dump entire tree (10 events)"
);
// Test: (None, Some(_)) dumps entire tree
let result = indexer.get_events_in_id_range(None, Some(8)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
);
// Test: start_id before buffer triggers tree dump
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"start_id=7 (before buffer) should dump entire tree"
);
let result = indexer.get_events_in_id_range(Some(5), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"range [5, 12] extending before buffer should dump entire tree"
);
// ========== EDGE CASES ==========
// Single element when start == end (inclusive range)
let result = indexer.get_events_in_id_range(Some(12), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12],
"start == end should return single event"
);
// InvalidRange when start > end
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(
matches!(result, WorkerKvQueryResponse::InvalidRange { .. }),
"start > end should return InvalidRange"
);
// TooNew when start_id is beyond buffer
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(
matches!(result, WorkerKvQueryResponse::TooNew { .. }),
"start_id beyond buffer should return TooNew"
);
// Request with end beyond buffer but valid start -> buffer returns what it has
let result = indexer.get_events_in_id_range(Some(12), Some(100)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12, 13, 14],
"range with end beyond buffer should return available buffer events"
);
}
}
// Implement KvIndexerInterface by delegating to the underlying indexer
#[async_trait]
impl KvIndexerInterface for LocalKvIndexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches(sequence).await
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches_for_request(tokens).await
}
async fn apply_event(&mut self, event: RouterEvent) {
// Use the buffering version
let _ = self.apply_event_with_buffer(event).await;
}
async fn remove_worker(&mut self, worker: WorkerId) {
let _ = self.indexer.remove_worker_sender().send(worker).await;
}
fn shutdown(&mut self) {
// Note: Since indexer is Arc<KvIndexer>, we can't call mutable methods directly.
// The indexer will be shut down when the CancellationToken is cancelled
// or when the last Arc reference is dropped.
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision_for_request(tokens, worker)
.await
}
}
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
}
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
///
/// ## Sharding Strategy
/// - Each worker is **permanently assigned** to a single shard on first event
/// - All KV blocks from a worker exist only in that worker's assigned shard
/// - New workers are assigned to the shard with the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own thread with a single-threaded runtime
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>,
event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Vec<JoinHandle<()>>,
}
impl KvIndexerSharded {
/// Create a new `KvIndexerSharded`.
///
/// ### Arguments
///
/// * `token` - A `CancellationToken` for managing shutdown.
/// * `shards` - A list of kvindexer shards.
/// * `expiration_duration` - The amount of time that block usage should be buffered.
/// * `ttl` - The time-to-live for blocks before they expire.
/// * `prune_config` - Configuration for tree-size based pruning.
///
/// ### Returns
///
/// A new `KvIndexer`.
pub fn new_with_frequency(
token: CancellationToken,
num_shards: usize,
expiration_duration: Option<Duration>,
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
let mut tasks = Vec::new();
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
for _ in 0..num_shards {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
......@@ -2429,8 +2094,8 @@ impl Drop for KvIndexerSharded {
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use rstest::rstest;
use rstest_reuse::{self, *};
use tokio::time;
......@@ -3072,614 +2737,813 @@ mod tests {
if num_shards == 1 {
Box::new(KvIndexer::new(token.clone(), kv_block_size, metrics.into()))
} else {
Box::new(KvIndexerSharded::new(
token.clone(),
num_shards,
kv_block_size,
metrics.into(),
))
Box::new(KvIndexerSharded::new(
token.clone(),
num_shards,
kv_block_size,
metrics.into(),
))
}
}
#[template]
#[rstest]
fn indexer_template(
#[values(1, 3, 8)] num_shards: usize,
#[values(11, 32, 64)] kv_block_size: usize,
) {
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
setup();
let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let sequence = vec![compute_block_hash(b"test data")];
let scores = kv_indexer.find_matches(sequence).await;
assert!(scores.unwrap().scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let tokens = vec![1, 2, 3, 4];
let scores = kv_indexer.find_matches_for_request(&tokens).await;
assert!(scores.unwrap().scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
setup();
let worker_id = 0;
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
kv_indexer.apply_event(event).await;
// No assertion here, just ensuring it runs without panic
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
kv_indexer.shutdown();
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: u32) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
setup();
let mut kv_indexer: Box<dyn KvIndexerInterface>;
let token = CancellationToken::new();
let expiration = Duration::from_millis(50);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
if num_shards == 1 {
kv_indexer = Box::new(KvIndexer::new_with_frequency(
token,
Some(expiration),
kv_block_size,
metrics,
None,
));
} else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
token,
num_shards,
Some(expiration),
kv_block_size,
metrics,
None,
));
}
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Should be no cached blocks yet"
);
// Blocks go in cache
let worker_id = 0;
let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
kv_indexer.apply_event(event).await;
// First access
// The store event is applied async so poll briefly
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
);
// Second access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
);
// Let those two accesses expire
time::sleep(expiration + Duration::from_millis(10)).await;
// New first access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
#[test]
fn test_router_event_new() {
setup();
let worker_id = 0;
let kv_cache_event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(0),
tokens_hash: LocalBlockHash(13226331709069118873),
}],
}),
dp_rank: 0,
};
let router_event = RouterEvent::new(worker_id, kv_cache_event);
assert_eq!(router_event.worker_id, worker_id);
assert_eq!(router_event.event.event_id, 1);
if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
assert_eq!(store_op.blocks.len(), 1);
assert_eq!(
store_op.blocks[0].tokens_hash,
compute_block_hash(b"test data")
);
assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
} else {
panic!("Expected KvCacheEventData::Stored");
}
}
#[template]
#[rstest]
fn indexer_template(
#[values(1, 3, 8)] num_shards: usize,
#[values(11, 32, 64)] kv_block_size: usize,
) {
#[test]
fn test_radix_tree_default() {
setup();
let radix_tree: RadixTree = Default::default();
assert!(radix_tree.root.borrow().children.is_empty());
assert!(radix_tree.root.borrow().workers.is_empty());
assert!(radix_tree.lookup.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
#[test]
fn test_overlap_scores_default() {
setup();
let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
let overlap_scores: OverlapScores = Default::default();
assert!(overlap_scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
async fn test_dump_tree_as_events_round_trip() {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let sequence = vec![compute_block_hash(b"test data")];
let scores = kv_indexer.find_matches(sequence).await;
// Configuration
let kv_block_size = 32;
let num_shards = 2;
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
assert!(scores.unwrap().scores.is_empty());
}
// Build a non-trivial indexer with events
let token1 = CancellationToken::new();
let mut original_indexer =
KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
let tokens = vec![1, 2, 3, 4];
let scores = kv_indexer.find_matches_for_request(&tokens).await;
// Apply events to the original indexer
original_indexer
.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.await;
assert!(scores.unwrap().scores.is_empty());
}
original_indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_1,
2,
vec![4, 5],
Some(ExternalSequenceBlockHash(100)),
))
.await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
setup();
let worker_id = 0;
original_indexer
.apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
.await;
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
original_indexer
.apply_event(create_store_event(
worker_0,
4,
vec![4],
Some(ExternalSequenceBlockHash(100)),
))
.await;
let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
kv_indexer.apply_event(event).await;
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// No assertion here, just ensuring it runs without panic
}
// Dump the original indexer
let dump1 = original_indexer.dump_events().await.unwrap();
println!("Dumped {} events", dump1.len());
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
// Create a new indexer and apply all dumped events
let token2 = CancellationToken::new();
let mut reconstructed_indexer =
KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);
kv_indexer.shutdown();
}
for event in &dump1 {
reconstructed_indexer.apply_event(event.clone()).await;
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: u32) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
setup();
let mut kv_indexer: Box<dyn KvIndexerInterface>;
let token = CancellationToken::new();
let expiration = Duration::from_millis(50);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
// Dump the reconstructed indexer
let dump2 = reconstructed_indexer.dump_events().await.unwrap();
if num_shards == 1 {
kv_indexer = Box::new(KvIndexer::new_with_frequency(
token,
Some(expiration),
kv_block_size,
metrics,
None,
));
} else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
token,
num_shards,
Some(expiration),
kv_block_size,
metrics,
None,
));
// Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
let mut sorted_dump1 = dump1.clone();
let mut sorted_dump2 = dump2.clone();
// Sort by (worker_id, tokens_hash, parent_hash)
let sort_key = |event: &RouterEvent| {
if let KvCacheEventData::Stored(ref data) = event.event.data {
(
event.worker_id,
data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
data.parent_hash.map(|h| h.0).unwrap_or(0),
)
} else {
(event.worker_id, 0, 0)
}
};
sorted_dump1.sort_by_key(sort_key);
sorted_dump2.sort_by_key(sort_key);
// Verify the dumps have the same length
assert_eq!(
sorted_dump1.len(),
sorted_dump2.len(),
"Dumps have different lengths: {} vs {}",
sorted_dump1.len(),
sorted_dump2.len()
);
// Verify each event matches
for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
assert_eq!(
event1.worker_id, event2.worker_id,
"Event {} worker_id mismatch",
i
);
if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
(&event1.event.data, &event2.event.data)
{
assert_eq!(
data1.parent_hash, data2.parent_hash,
"Event {} parent_hash mismatch",
i
);
assert_eq!(
data1.blocks.len(),
data2.blocks.len(),
"Event {} blocks length mismatch",
i
);
for (j, (block1, block2)) in
data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
{
assert_eq!(
block1.tokens_hash, block2.tokens_hash,
"Event {} block {} tokens_hash mismatch",
i, j
);
assert_eq!(
block1.block_hash, block2.block_hash,
"Event {} block {} block_hash mismatch",
i, j
);
}
} else {
panic!("Expected Stored events in both dumps");
}
}
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
// Also verify that both indexers produce the same match results
for test_seq in [
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
vec![LocalBlockHash(6), LocalBlockHash(7)],
vec![LocalBlockHash(1)],
] {
let scores1 = original_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
let scores2 = reconstructed_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
// Sort the scores to compare
let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
scores1_sorted.sort_by_key(|(k, _)| *k);
scores2_sorted.sort_by_key(|(k, _)| *k);
assert_eq!(
scores1_sorted, scores2_sorted,
"Match scores differ for sequence {:?}",
test_seq
);
}
// Clean up
original_indexer.shutdown();
reconstructed_indexer.shutdown();
}
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
assert_eq!(
overlap.frequencies.len(),
0,
"Should be no cached blocks yet"
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
.unwrap()
.get(),
1
);
// Blocks go in cache
let worker_id = 0;
let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
kv_indexer.apply_event(event).await;
// First access
// The store event is applied async so poll briefly
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
metrics.increment_event_applied(
METRIC_EVENT_STORED,
Err(KvCacheEventError::ParentBlockNotFound),
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_STORED,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap()
.get(),
1
);
// Second access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
metrics
.increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_REMOVED,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap()
.get(),
1
);
}
// Let those two accesses expire
time::sleep(expiration + Duration::from_millis(10)).await;
#[test]
fn test_remove_worker_verifies_hash_removal() {
setup();
let mut trie = RadixTree::new();
// New first access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
// Add blocks for multiple workers
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
.unwrap();
// Verify worker_0 has 3 blocks in lookup
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.len(),
3
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
// Verify that blocks have the correct workers
let block_1 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
// Remove worker_0
trie.remove_worker(worker_0);
#[test]
fn test_router_event_new() {
setup();
let worker_id = 0;
let kv_cache_event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(0),
tokens_hash: LocalBlockHash(13226331709069118873),
}],
}),
dp_rank: 0,
};
let router_event = RouterEvent::new(worker_id, kv_cache_event);
// Verify worker_0 is completely removed from lookup table
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.len(), 2);
assert_eq!(router_event.worker_id, worker_id);
assert_eq!(router_event.event.event_id, 1);
if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
assert_eq!(store_op.blocks.len(), 1);
assert_eq!(
store_op.blocks[0].tokens_hash,
compute_block_hash(b"test data")
);
assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
} else {
panic!("Expected KvCacheEventData::Stored");
}
}
// Verify that worker_0's hash is removed from the workers set
let block_1 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(
!block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
#[test]
fn test_radix_tree_default() {
setup();
let radix_tree: RadixTree = Default::default();
assert!(radix_tree.root.borrow().children.is_empty());
assert!(radix_tree.root.borrow().workers.is_empty());
assert!(radix_tree.lookup.is_empty());
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let block_2 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
assert!(
block_2
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
// Verify match results no longer include worker_0
let result = trie
.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 2);
assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));
}
#[test]
fn test_overlap_scores_default() {
setup();
let overlap_scores: OverlapScores = Default::default();
assert!(overlap_scores.scores.is_empty());
// LocalKvIndexer tests
fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
32,
);
{
let mut buffer = indexer.event_buffer.lock().unwrap();
for &id in ids {
buffer.push_back(RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Cleared,
dp_rank: 0,
},
));
}
}
indexer
}
#[tokio::test]
async fn test_dump_tree_as_events_round_trip() {
setup();
// Configuration
let kv_block_size = 32;
let num_shards = 2;
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
async fn returns_slice_within_range() {
let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]);
// Build a non-trivial indexer with events
let token1 = CancellationToken::new();
let mut original_indexer =
KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
_ => panic!("Unexpected response type"),
}
};
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Apply events to the original indexer
original_indexer
.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.await;
// Test get_events_in_id_range (buffer queries)
// Range is [start, end] inclusive
let result = indexer.get_events_in_id_range(Some(2), Some(4)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4]
original_indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
original_indexer
.apply_event(create_store_event(
worker_1,
2,
vec![4, 5],
Some(ExternalSequenceBlockHash(100)),
))
.await;
let result = indexer.get_events_in_id_range(Some(2), Some(6)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max
original_indexer
.apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
.await;
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
original_indexer
.apply_event(create_store_event(
worker_0,
4,
vec![4],
Some(ExternalSequenceBlockHash(100)),
))
.await;
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![3]); // single element when start == end
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// Invalid range: end < start
let result = indexer.get_events_in_id_range(Some(5), Some(2)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
}
// Dump the original indexer
let dump1 = original_indexer.dump_events().await.unwrap();
println!("Dumped {} events", dump1.len());
#[tokio::test]
async fn test_get_events_in_id_range_all_cases() {
// Create indexer with small buffer (5 events max)
// This way older events will only be in the tree, not the buffer
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4, // block_size
Arc::new(KvIndexerMetrics::new_unregistered()),
5, // max_buffer_size - only keeps 5 most recent events
);
// Create a new indexer and apply all dumped events
let token2 = CancellationToken::new();
let mut reconstructed_indexer =
KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0, // worker_id
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
}],
}),
dp_rank: 0,
},
)
};
for event in &dump1 {
reconstructed_indexer.apply_event(event.clone()).await;
// Add 10 events (IDs 5-14)
// Buffer will only keep the last 5: events 10-14
// Tree will have all blocks
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
// Allow some time for events to be processed
tokio::time::sleep(Duration::from_millis(50)).await;
// Dump the reconstructed indexer
let dump2 = reconstructed_indexer.dump_events().await.unwrap();
// Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
let mut sorted_dump1 = dump1.clone();
let mut sorted_dump2 = dump2.clone();
// Wait for events to be processed by the tree
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Sort by (worker_id, tokens_hash, parent_hash)
let sort_key = |event: &RouterEvent| {
if let KvCacheEventData::Stored(ref data) = event.event.data {
(
event.worker_id,
data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
data.parent_hash.map(|h| h.0).unwrap_or(0),
)
} else {
(event.worker_id, 0, 0)
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
sorted_dump1.sort_by_key(sort_key);
sorted_dump2.sort_by_key(sort_key);
// Helper to extract event IDs from result
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Verify the dumps have the same length
// Verify buffer state: should have events 10-14 (last 5)
let buffer_events = indexer.get_all_events_in_buffer();
assert_eq!(
sorted_dump1.len(),
sorted_dump2.len(),
"Dumps have different lengths: {} vs {}",
sorted_dump1.len(),
sorted_dump2.len()
get_ids(buffer_events),
vec![10, 11, 12, 13, 14],
"Buffer should have events 10-14"
);
// Verify each event matches
for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
assert_eq!(
event1.worker_id, event2.worker_id,
"Event {} worker_id mismatch",
i
);
if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
(&event1.event.data, &event2.event.data)
{
assert_eq!(
data1.parent_hash, data2.parent_hash,
"Event {} parent_hash mismatch",
i
);
assert_eq!(
data1.blocks.len(),
data2.blocks.len(),
"Event {} blocks length mismatch",
i
);
for (j, (block1, block2)) in
data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
{
assert_eq!(
block1.tokens_hash, block2.tokens_hash,
"Event {} block {} tokens_hash mismatch",
i, j
);
assert_eq!(
block1.block_hash, block2.block_hash,
"Event {} block {} block_hash mismatch",
i, j
);
}
} else {
panic!("Expected Stored events in both dumps");
}
}
// Also verify that both indexers produce the same match results
for test_seq in [
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
vec![LocalBlockHash(6), LocalBlockHash(7)],
vec![LocalBlockHash(1)],
] {
let scores1 = original_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
let scores2 = reconstructed_indexer
.find_matches(test_seq.clone())
.await
.unwrap();
// Sort the scores to compare
let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
scores1_sorted.sort_by_key(|(k, _)| *k);
scores2_sorted.sort_by_key(|(k, _)| *k);
assert_eq!(
scores1_sorted, scores2_sorted,
"Match scores differ for sequence {:?}",
test_seq
);
}
// Clean up
original_indexer.shutdown();
reconstructed_indexer.shutdown();
}
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
// ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
// Range is [start, end] inclusive
metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
// Test: start_id within buffer, no end
let result = indexer.get_events_in_id_range(Some(11), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
.unwrap()
.get(),
1
get_ids(extract_events(result)),
vec![11, 12, 13, 14],
"start_id=11 (in buffer) should return [11, 14]"
);
metrics.increment_event_applied(
METRIC_EVENT_STORED,
Err(KvCacheEventError::ParentBlockNotFound),
);
// Test: start_id at buffer boundary
let result = indexer.get_events_in_id_range(Some(10), None).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_STORED,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap()
.get(),
1
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"start_id=10 (buffer start) should return [10, 14]"
);
metrics
.increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
// Test: both start and end within buffer (inclusive)
let result = indexer.get_events_in_id_range(Some(11), Some(13)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_REMOVED,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap()
.get(),
1
get_ids(extract_events(result)),
vec![11, 12, 13],
"range [11, 13] inclusive should return 3 events"
);
}
#[test]
fn test_remove_worker_verifies_hash_removal() {
setup();
let mut trie = RadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![10, 11, 12, 13, 14],
"range [10, 14] should return all buffer events"
);
// Add blocks for multiple workers
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
.unwrap();
// ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
// Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
// that we get events back (the IDs won't match original IDs)
// Verify worker_0 has 3 blocks in lookup
// Test: (None, None) dumps entire tree
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.len(),
3
extract_events(result).len(),
10,
"(None, None) should dump entire tree (10 events)"
);
// Verify that blocks have the correct workers
let block_1 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
// Test: (None, Some(_)) dumps entire tree
let result = indexer.get_events_in_id_range(None, Some(8)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
// Test: start_id before buffer triggers tree dump
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"start_id=7 (before buffer) should dump entire tree"
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
let result = indexer.get_events_in_id_range(Some(5), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert_eq!(
extract_events(result).len(),
10,
"range [5, 12] extending before buffer should dump entire tree"
);
// Remove worker_0
trie.remove_worker(worker_0);
// ========== EDGE CASES ==========
// Verify worker_0 is completely removed from lookup table
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
// Single element when start == end (inclusive range)
let result = indexer.get_events_in_id_range(Some(12), Some(12)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12],
"start == end should return single event"
);
assert_eq!(trie.lookup.len(), 2);
// Verify that worker_0's hash is removed from the workers set
let block_1 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(
!block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
// InvalidRange when start > end
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
matches!(result, WorkerKvQueryResponse::InvalidRange { .. }),
"start > end should return InvalidRange"
);
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let block_2 = trie
.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
// TooNew when start_id is beyond buffer
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(
block_2
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
matches!(result, WorkerKvQueryResponse::TooNew { .. }),
"start_id beyond buffer should return TooNew"
);
// Verify match results no longer include worker_0
let result = trie
.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 2);
assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));
}
}
#[cfg(test)]
mod tests_local_indexer {
use super::*;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use tokio::time;
use tokio_util::sync::CancellationToken;
fn setup() {
dynamo_runtime::logging::init();
}
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
})
.collect()
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent,
blocks: make_blocks(hashes),
}),
dp_rank: 0,
},
}
// Request with end beyond buffer but valid start -> buffer returns what it has
let result = indexer.get_events_in_id_range(Some(12), Some(100)).await;
assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
assert_eq!(
get_ids(extract_events(result)),
vec![12, 13, 14],
"range with end beyond buffer should return available buffer events"
);
}
#[tokio::test]
......@@ -3752,49 +3616,4 @@ mod tests_local_indexer {
_ => panic!("Expected Stored event"),
}
}
#[tokio::test]
async fn test_gap_detection_per_worker() {
setup();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let indexer = KvIndexer::new(token.clone(), 4, metrics);
let worker_a: WorkerId = 100;
let worker_b: WorkerId = 200;
let event_tx = indexer.event_sender();
// Worker A: events 1, 2, 3 (no gap)
for id in 1..=3 {
let event = create_store_event(worker_a, id, vec![id], None);
event_tx.send(event).await.unwrap();
}
// Worker B: events 1, then 5 (gap of 2, 3, 4)
let event_b1 = create_store_event(worker_b, 1, vec![10], None);
event_tx.send(event_b1).await.unwrap();
let event_b5 = create_store_event(worker_b, 5, vec![50], None);
event_tx.send(event_b5).await.unwrap();
// Give time for events to be processed
time::sleep(Duration::from_millis(20)).await;
// Verify each worker has correct last_received_event_id
let last_ids = indexer.get_last_received_event_ids().await.unwrap();
assert_eq!(
last_ids.get(&worker_a),
Some(&3),
"Worker A should have last_id = 3 (no gap)"
);
assert_eq!(
last_ids.get(&worker_b),
Some(&5),
"Worker B should have last_id = 5 (despite gap)"
);
// Cleanup
token.cancel();
}
}
......@@ -36,6 +36,12 @@ use crate::kv_router::{
};
use dynamo_runtime::config::environment_names::nats as env_nats;
// Error handling configuration for ZMQ operations
const INITIAL_BACKOFF_MS: u64 = 10;
const MAX_BACKOFF_MS: u64 = 5000;
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
const MAX_BACKOFF_EXPONENT: u32 = 8; // Cap at 2^8 = 256x multiplier to prevent overflow
// -------------------------------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
......@@ -125,15 +131,14 @@ impl KvEventPublisher {
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
let component_name = component.name();
tracing::info!(
worker_id,
component = component.name(),
"Initializing KvEventPublisher for worker {worker_id} in component {component}"
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
);
if enable_local_indexer {
tracing::info!(
"LocalKvIndexer enabled for worker {worker_id} in component {component}"
"LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
);
}
......@@ -321,27 +326,25 @@ async fn start_worker_kv_query_service(
let mut subscriber = match component.subscribe(&subject).await {
Ok(sub) => sub,
Err(e) => {
tracing::error!("Failed to subscribe to {}: {}", subject, e);
return; // No ? because function doesn't return Result
tracing::error!(
"Query service failed to subscribe for worker {worker_id} on subject {subject}: {e}"
);
return;
}
};
tracing::debug!(
"Query service on worker {} listening on NATS subject: {}",
worker_id,
subject
);
tracing::info!("Query service listening on NATS for worker {worker_id} on subject {subject}");
// Receive query request from router, retrieve event(s) from LocalKvIndexer, return response
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::info!("Router-Worker communication channel received cancellation signal");
tracing::info!("Query service received cancellation signal for worker {worker_id}");
break;
}
msg = subscriber.next() => {
let Some(msg) = msg else {
tracing::debug!("Router-Worker stream ended.");
tracing::warn!("Query service NATS stream ended for worker {worker_id}");
break;
};
......@@ -349,12 +352,12 @@ async fn start_worker_kv_query_service(
let request: WorkerKvQueryRequest = match serde_json::from_slice(&msg.payload) {
Ok(request) => request,
Err(e) => {
tracing::error!("Failed to deserialize WorkerKvQueryRequest: {}", e);
tracing::error!("Failed to deserialize WorkerKvQueryRequest for worker {worker_id}: {e}");
continue;
}
};
tracing::debug!("Received WorkerKvQueryRequest: {:?}", request);
tracing::debug!("Received query request for worker {worker_id}: {request:?}");
// Query events based on optional start/end ids
let response = local_indexer
......@@ -366,7 +369,7 @@ async fn start_worker_kv_query_service(
let payload = match serde_json::to_vec(&response) {
Ok(p) => p,
Err(e) => {
tracing::error!("Failed to serialize response: {}", e);
tracing::error!("Failed to serialize response for worker {worker_id}: {e}");
continue;
}
};
......@@ -377,22 +380,14 @@ async fn start_worker_kv_query_service(
.kv_router_nats_publish(reply_subject.to_string(), payload.into())
.await
{
tracing::error!("Failed to send reply: {}", e);
tracing::error!("Failed to send reply for worker {worker_id}: {e}");
}
}
}
}
}
}
// Error handling configuration for ZMQ operations
const INITIAL_BACKOFF_MS: u64 = 10;
const MAX_BACKOFF_MS: u64 = 5000;
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
const MAX_BACKOFF_EXPONENT: u32 = 8; // Cap at 2^8 = 256x multiplier to prevent overflow
/// Calculate exponential backoff duration based on consecutive error count
fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
std::cmp::min(
......@@ -481,7 +476,7 @@ pub async fn start_zmq_listener(
let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect();
if frames.len() != 3 {
tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count");
tracing::warn!("Received unexpected ZMQ frame count: expected 3, actual {}", frames.len());
continue;
}
......@@ -490,7 +485,7 @@ pub async fn start_zmq_listener(
let seq_bytes = frames.pop().unwrap();
if seq_bytes.len() != 8 {
tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length");
tracing::warn!("Invalid sequence number byte length: expected 8, actual {}", seq_bytes.len());
continue;
}
......@@ -500,7 +495,7 @@ pub async fn start_zmq_listener(
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
let Ok(batch) = batch_result else {
let e = batch_result.unwrap_err();
tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack");
tracing::warn!("Failed to decode KVEventBatch msgpack: {e}");
continue;
};
......@@ -1821,16 +1816,10 @@ mod tests_startup_helpers {
"Router should only see 1 shared block (not the new block from event_2)"
);
// === STEP 4 & 5: Recovery - Query last received event IDs and fetch missed events ===
// Step 4a: Router queries its last received event ID per worker
let last_ids = router_indexer.get_last_received_event_ids().await.unwrap();
let last_known_id = last_ids.get(&worker_1_id).copied().unwrap_or(0);
assert_eq!(
last_known_id, 1,
"Router should have last_received_event_id = 1 for worker (only event_1 was forwarded)"
);
// Step 4b: Query worker's local indexer for events after last_known_id
// === STEP 4 & 5: Recovery - Query worker's local indexer for missed events ===
// In practice, the subscriber detects gaps and triggers recovery automatically.
// Here we simulate that by querying for events after event_id=1.
let last_known_id = 1u64; // Router only received event_1
let response = local_indexer_1
.get_events_in_id_range(Some(last_known_id + 1), None)
.await;
......@@ -1868,14 +1857,6 @@ mod tests_startup_helpers {
"Router should now see both blocks after recovery"
);
// assert: Router's last_received_event_id is updated after recovery
let last_ids_after = router_indexer.get_last_received_event_ids().await.unwrap();
assert_eq!(
last_ids_after.get(&worker_1_id),
Some(&2),
"Router should have last_received_event_id = 2 after recovery"
);
token.cancel();
}
}
......@@ -2043,8 +2024,6 @@ mod test_integration_publisher {
#[tokio::test]
#[ignore] // Mark as ignored as requested, because CI's integrations still don't have NATS
async fn test_kvstats_prometheus_gauge_updates() {
use crate::kv_router::publisher::kvstats;
// Test that publish() updates Prometheus gauges correctly using real Component
let publisher = WorkerMetricsPublisher::new().unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Scoring functions for the KV router.
use super::protocols::LoadMetrics;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Endpoint {
pub name: String,
pub subject: String,
pub data: LoadMetrics,
}
impl Endpoint {
pub fn worker_id(&self) -> u64 {
u64::from_str_radix(
self.subject
.split("-")
.last()
.expect("invalid subject")
.to_string()
.as_str(),
16,
)
.expect("invalid worker id")
}
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints {
pub endpoints: HashMap<u64, Endpoint>,
pub load_avg: f64,
pub load_std: f64,
}
impl ProcessedEndpoints {
pub fn new(endpoints: Vec<Endpoint>) -> Self {
// compute some basic statistics
let load_values: Vec<f64> = endpoints
.iter()
.map(|endpoint| endpoint.data.kv_active_blocks() as f64)
.collect();
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
let variance = load_values
.iter()
.map(|&x| (x - load_avg).powi(2))
.sum::<f64>()
/ load_values.len() as f64;
let load_std = variance.sqrt();
let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();
ProcessedEndpoints {
endpoints,
load_avg,
load_std,
}
}
pub fn worker_ids(&self) -> Vec<u64> {
self.endpoints.keys().copied().collect()
}
pub fn active_blocks(&self) -> HashMap<u64, usize> {
self.endpoints
.iter()
.map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
.collect()
}
}
......@@ -32,6 +32,10 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
const CHECK_INTERVAL_JITTER_MS: i64 = 100;
// Worker query retry configuration
const WORKER_QUERY_MAX_RETRIES: u32 = 8;
const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
......@@ -65,8 +69,7 @@ pub async fn recover_from_all_workers(
// Skip workers without local indexer
if !worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
worker_id,
"Skipping recovery - worker does not have local indexer enabled"
"Skipping recovery - worker {worker_id} does not have local indexer enabled"
);
continue;
}
......@@ -101,10 +104,7 @@ pub async fn recover_from_all_workers(
// Log summary
if total_recovered > 0 || failed_workers > 0 {
tracing::info!(
total_recovered,
successful_workers,
failed_workers,
"Startup recovery completed"
"Startup recovery completed: {total_recovered} events recovered from {successful_workers} workers, {failed_workers} workers failed"
);
}
......@@ -133,35 +133,61 @@ pub async fn recover_from_worker(
) -> Result<usize> {
if worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
worker_id,
start_event_id = ?start_event_id,
end_event_id = ?end_event_id,
"Attempting recovery from worker"
"Attempting recovery from worker {worker_id}, start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
} else {
tracing::warn!(
worker_id,
"Worker does not have local indexer enabled, skipping recovery"
);
tracing::warn!("Worker {worker_id} does not have local indexer enabled, skipping recovery");
return Ok(0);
}
// Query worker for events in range
let response = worker_query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await?;
// Query worker for events in range, with retry logic for transient failures
// (e.g., worker's query service not yet re-subscribed after NATS restart)
let mut response = None;
let mut last_error = None;
for attempt in 0..WORKER_QUERY_MAX_RETRIES {
match worker_query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!("Worker {worker_id} query succeeded after retry {attempt}");
}
response = Some(resp);
break;
}
Err(e) => {
last_error = Some(e);
if attempt < WORKER_QUERY_MAX_RETRIES - 1 {
let backoff_ms = WORKER_QUERY_INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
tracing::warn!(
"Worker {worker_id} query failed on attempt {attempt}, retrying after {backoff_ms}ms"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
}
let response = match response {
Some(r) => r,
None => return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("No response"))),
};
// Handle response variants
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(worker_id, count = events.len(), "Got buffered events");
tracing::debug!(
"Got {count} buffered events from worker {worker_id}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
worker_id,
count = events.len(),
"Got tree dump (range too old or unspecified)"
"Got tree dump from worker {worker_id} (range too old or unspecified), count: {count}",
count = events.len()
);
events
}
......@@ -171,11 +197,7 @@ pub async fn recover_from_worker(
newest_available,
} => {
tracing::warn!(
worker_id,
?requested_start,
?requested_end,
newest_available,
"Requested range is newer than available data"
"Worker {worker_id} requested range is newer than available data: requested_start: {requested_start:?}, requested_end: {requested_end:?}, newest_available: {newest_available}"
);
return Ok(0);
}
......@@ -188,24 +210,21 @@ pub async fn recover_from_worker(
if events_count == 0 {
tracing::debug!(
worker_id,
start_event_id = ?start_event_id,
"No events to recover from worker"
"No events to recover from worker {worker_id}, start_event_id: {start_event_id:?}"
);
return Ok(0);
}
tracing::info!(
worker_id,
start_event_id = ?start_event_id,
events_count,
"Recovered {events_count} events from worker"
"Recovered {events_count} events from worker {worker_id}, start_event_id: {start_event_id:?}"
);
// Apply recovered events to the indexer
for event in events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(worker_id, error = %e, "Failed to send recovered event to indexer");
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
}
......@@ -528,8 +547,7 @@ pub async fn start_kv_router_background(
};
tracing::warn!(
worker_id = worker_id,
"DISCOVERY: Generate endpoint instance removed, removing worker"
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
......@@ -611,8 +629,7 @@ pub async fn start_kv_router_background(
let consumer_to_delete = router_instance_id.to_string();
tracing::info!(
router_instance_id = router_instance_id,
"DISCOVERY: Router instance removed, attempting to delete orphaned consumer: {consumer_to_delete}"
"DISCOVERY: Router instance {router_instance_id} removed, attempting to delete orphaned consumer: {consumer_to_delete}"
);
// Delete the consumer (allow race condition if multiple routers try to delete)
......@@ -653,11 +670,11 @@ pub async fn start_kv_router_background_nats_core(
) -> Result<()> {
// Subscribe to KV events using NATS Core
let mut subscriber = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_event_subject = format!("{}.{}", component.subject(), KV_EVENT_SUBJECT);
tracing::info!(
"KV Router using NATS Core subscription on subject: {}.{} (local_indexer mode)",
component.subject(),
KV_EVENT_SUBJECT
subject = %kv_event_subject,
"KV Router using NATS Core subscription (local_indexer mode)"
);
// Get the generate endpoint and watch for instance events (add/remove)
......@@ -696,8 +713,7 @@ pub async fn start_kv_router_background_nats_core(
let worker_id = _instance.instance_id();
tracing::info!(
worker_id = worker_id,
"DISCOVERY: Worker added, dumping local indexer into router"
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
);
// Query worker's local indexer and dump all events
......@@ -712,24 +728,19 @@ pub async fn start_kv_router_background_nats_core(
{
Ok(count) => {
tracing::info!(
worker_id = worker_id,
events_recovered = count,
"Successfully dumped worker's local indexer"
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err(e) => {
tracing::warn!(
worker_id = worker_id,
error = %e,
"Failed to dump worker's local indexer (may not have local indexer enabled)"
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
);
}
}
}
DiscoveryEvent::Removed(worker_id) => {
tracing::warn!(
worker_id = worker_id,
"DISCOVERY: Worker removed, removing from router indexer"
"DISCOVERY: Worker {worker_id} removed, removing from router indexer"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
......@@ -760,12 +771,9 @@ pub async fn start_kv_router_background_nats_core(
// Gap detected - recover missing events before processing current
let gap_start = last_id + 1;
let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1;
tracing::warn!(
worker_id,
gap_start,
gap_end,
gap_size = gap_end - gap_start + 1,
"Event ID gap detected, recovering events [{gap_start}, {gap_end}]"
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
// Note: While recovering, new events may queue in the NATS subscriber's
......@@ -779,11 +787,7 @@ pub async fn start_kv_router_background_nats_core(
&kv_events_tx,
).await {
tracing::error!(
worker_id,
gap_start,
gap_end,
error = %e,
"Failed to recover gap events; proceeding with current event anyway"
"Failed to recover gap events for worker {worker_id} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
);
// Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event.
......
......@@ -305,6 +305,8 @@ class NatsServer(ManagedProcess):
self.port = port
self.use_random_port = use_random_port # Track if we allocated the port
self._request = request # Store for restart
self._timeout = timeout
data_dir = tempfile.mkdtemp(prefix="nats_")
command = [
"nats-server",
......@@ -336,6 +338,39 @@ class NatsServer(ManagedProcess):
return super().__exit__(exc_type, exc_val, exc_tb)
def stop(self):
"""Stop the NATS server for restart. Does not release port or clean up fully."""
_logger.info(f"Stopping NATS server on port {self.port}")
self._terminate_process_group()
if self.proc:
try:
self.proc.wait(timeout=10)
except Exception as e:
_logger.warning(f"Error waiting for NATS process to stop: {e}")
self.proc = None
def start(self):
"""Restart a stopped NATS server with fresh state."""
_logger.info(f"Starting NATS server on port {self.port} with fresh state")
# Clean up old data directory and create fresh one
if self.data_dir:
shutil.rmtree(self.data_dir, ignore_errors=True)
self.data_dir = tempfile.mkdtemp(prefix="nats_")
# Rebuild command with new data_dir
self.command = [
"nats-server",
"-js",
"--trace",
"--store_dir",
self.data_dir,
"-p",
str(self.port),
]
self._start_process()
self._check_ports(self._timeout)
class SharedManagedProcess:
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
......
......@@ -8,7 +8,7 @@ import os
import random
import string
import time
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import aiohttp
import nats
......@@ -16,6 +16,9 @@ import nats
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
from tests.utils.managed_process import ManagedProcess
if TYPE_CHECKING:
from tests.conftest import NatsServer
logger = logging.getLogger(__name__)
NUM_REQUESTS = 100
......@@ -220,7 +223,7 @@ async def wait_for_frontend_ready(
logger.debug(f"Error checking models endpoint: {e}")
# Wait before next poll
await asyncio.sleep(2)
await asyncio.sleep(1)
# Phase 2: Wait for chat completions pipeline to be ready
logger.info("Waiting for chat completions pipeline to be built...")
......@@ -253,7 +256,7 @@ async def wait_for_frontend_ready(
logger.debug(f"Error testing chat completions: {e}")
# Wait before next poll
await asyncio.sleep(2)
await asyncio.sleep(1)
async def wait_for_workers_ready(
......@@ -1321,6 +1324,9 @@ def _test_router_indexers_sync(
model_name: str,
num_workers: int,
store_backend: str = "etcd",
request_plane: str = "nats",
test_nats_interruption: bool = False,
nats_server: Optional["NatsServer"] = None,
):
"""Test that two KV routers have synchronized indexer states after processing requests.
......@@ -1333,16 +1339,30 @@ def _test_router_indexers_sync(
This validates that the snapshot mechanism works and routers can sync state from NATS.
When test_nats_interruption=True (requires nats_server and request_plane="tcp"):
- After first router sends 25 requests, NATS is stopped
- 10 more requests sent while NATS is down (stored locally by local indexer)
- NATS restarted (fresh state), recovery mechanism re-syncs
- Second router starts and sends 25 requests
- NATS stopped again, 10 more requests sent
- NATS restarted, 5 more requests sent
- Verify both routers converge to same state
Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
block_size: Block size for KV cache
model_name: Model name to use for requests
num_workers: Expected number of workers
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats".
test_nats_interruption: If True, test NATS interruption recovery. Defaults to False.
nats_server: NatsServer instance for stop/start (required if test_nats_interruption=True).
Raises:
AssertionError: If router states don't synchronize correctly or snapshot is missing
"""
if test_nats_interruption and nats_server is None:
raise ValueError("nats_server is required when test_nats_interruption=True")
# Use async to manage the test flow
async def test_sync():
......@@ -1386,7 +1406,7 @@ def _test_router_indexers_sync(
# Create first runtime and endpoint for router 1
logger.info("Creating first KV router with its own runtime")
runtime1 = get_runtime(store_backend)
runtime1 = get_runtime(store_backend, request_plane)
namespace1 = runtime1.namespace(engine_workers.namespace)
component1 = namespace1.component(engine_workers.component_name)
endpoint1 = component1.endpoint("generate")
......@@ -1413,13 +1433,35 @@ def _test_router_indexers_sync(
successful1 == 25
), f"Expected 25 successful requests to router 1, got {successful1}"
# NATS interruption test: stop NATS, send requests, restart
if test_nats_interruption:
await asyncio.sleep(1)
assert nats_server is not None # Validated at function entry
logger.info("=== NATS INTERRUPTION TEST: Phase 1 ===")
logger.info("Stopping NATS server")
nats_server.stop()
logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline1 = await send_requests_to_router(
kv_push_router1, 10, "Router 1 (NATS down)", endpoint1
)
assert (
successful_offline1 == 10
), f"Expected 10 successful requests while NATS down, got {successful_offline1}"
logger.info("Restarting NATS server (fresh state)")
nats_server.start()
await asyncio.sleep(5)
# Wait for a second before creating the second router
logger.info("Waiting for 1 second before creating second router")
await asyncio.sleep(1)
# Create second runtime and endpoint for router 2
logger.info("Creating second KV router with its own runtime")
runtime2 = get_runtime(store_backend)
runtime2 = get_runtime(store_backend, request_plane)
namespace2 = runtime2.namespace(engine_workers.namespace)
component2 = namespace2.component(engine_workers.component_name)
endpoint2 = component2.endpoint("generate")
......@@ -1439,51 +1481,87 @@ def _test_router_indexers_sync(
successful2 == 25
), f"Expected 25 successful requests to router 2, got {successful2}"
# NATS interruption test: stop NATS again, send requests, restart, send more
if test_nats_interruption:
await asyncio.sleep(1)
assert nats_server is not None # Validated at function entry
logger.info("=== NATS INTERRUPTION TEST: Phase 2 ===")
logger.info("Stopping NATS server")
nats_server.stop()
logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline2 = await send_requests_to_router(
kv_push_router2, 10, "Router 2 (NATS down)", endpoint2
)
assert (
successful_offline2 == 10
), f"Expected 10 successful requests while NATS down, got {successful_offline2}"
logger.info("Restarting NATS server (fresh state)")
nats_server.start()
await asyncio.sleep(5)
logger.info("Sending 5 more requests after NATS recovery")
successful_recovery = await send_requests_to_router(
kv_push_router1, 5, "Router 1 (post-recovery)", endpoint1
)
assert (
successful_recovery == 5
), f"Expected 5 successful requests post-recovery, got {successful_recovery}"
# Wait for all requests to complete (they should already be complete from gather)
# Wait another 1 second for internal synchronization
logger.info("Waiting for final synchronization")
await asyncio.sleep(1)
# Verify NATS object store bucket was created with snapshot
# Mirror the Rust bucket naming logic from subscriber.rs:
# component.subject() -> "namespace.{ns}.component.{comp}"
# then slugify (convert dots to dashes, lowercase, etc) and append "-radix-bucket"
component_subject = f"namespace.{engine_workers.namespace}.component.{engine_workers.component_name}"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
expected_bucket = f"{slugified}-radix-bucket"
expected_file = "radix-state"
logger.info(f"Verifying NATS object store bucket exists: {expected_bucket}")
snapshot_verified = False
try:
# Connect to NATS and check object store
nc = await nats.connect("nats://localhost:4222")
try:
js = nc.jetstream()
obj_store = await js.object_store(expected_bucket)
# Skip this verification for NATS interruption test since NATS restarts fresh
# (local indexer recovery doesn't rely on NATS persistence)
if not test_nats_interruption:
# Mirror the Rust bucket naming logic from subscriber.rs:
# component.subject() -> "namespace.{ns}.component.{comp}"
# then slugify (convert dots to dashes, lowercase, etc) and append "-radix-bucket"
component_subject = f"namespace.{engine_workers.namespace}.component.{engine_workers.component_name}"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
expected_bucket = f"{slugified}-radix-bucket"
expected_file = "radix-state"
# Try to get the expected file
logger.info(f"Verifying NATS object store bucket exists: {expected_bucket}")
snapshot_verified = False
try:
# Connect to NATS and check object store
nc = await nats.connect("nats://localhost:4222")
try:
result = await obj_store.get(expected_file)
logger.info(
f"✓ Snapshot file '{expected_file}' found in bucket '{expected_bucket}' "
f"(size: {len(result.data) if result.data else 0} bytes)"
)
snapshot_verified = True
except Exception as e:
logger.error(
f"Snapshot file '{expected_file}' not found in bucket '{expected_bucket}': {e}"
)
finally:
await nc.close()
except Exception as e:
logger.error(f"Error checking NATS object store: {e}")
js = nc.jetstream()
obj_store = await js.object_store(expected_bucket)
# Assert that snapshot was created (threshold=20, sent 25 requests)
if not snapshot_verified:
assert False, (
f"Expected snapshot to be created in bucket '{expected_bucket}' with file '{expected_file}'. "
f"Router sent 25 requests with snapshot_threshold=20, so snapshot should have been triggered."
# Try to get the expected file
try:
result = await obj_store.get(expected_file)
logger.info(
f"✓ Snapshot file '{expected_file}' found in bucket '{expected_bucket}' "
f"(size: {len(result.data) if result.data else 0} bytes)"
)
snapshot_verified = True
except Exception as e:
logger.error(
f"Snapshot file '{expected_file}' not found in bucket '{expected_bucket}': {e}"
)
finally:
await nc.close()
except Exception as e:
logger.error(f"Error checking NATS object store: {e}")
# Assert that snapshot was created (threshold=20, sent 25 requests)
if not snapshot_verified:
assert False, (
f"Expected snapshot to be created in bucket '{expected_bucket}' with file '{expected_file}'. "
f"Router sent 25 requests with snapshot_threshold=20, so snapshot should have been triggered."
)
else:
logger.info(
"Skipping NATS object store verification (NATS was restarted fresh for interruption test)"
)
# Dump states from both routers
......@@ -1562,25 +1640,31 @@ def _test_router_indexers_sync(
logger.info("Successfully verified that both router states are equal")
# Verify NATS consumers are created (while routers are still alive)
logger.info("Verifying NATS consumers exist for both routers")
component_subject = f"namespace.{engine_workers.namespace}.component.{engine_workers.component_name}"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
stream_name = f"{slugified}-kv-events"
# Skip this for NATS interruption test since it uses local indexer (NATS Core, not JetStream)
if not test_nats_interruption:
logger.info("Verifying NATS consumers exist for both routers")
component_subject = f"namespace.{engine_workers.namespace}.component.{engine_workers.component_name}"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
stream_name = f"{slugified}-kv-events"
nc = await nats.connect("nats://localhost:4222")
try:
js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name)
consumer_names = [info.name for info in consumer_infos]
logger.info(f"Found {len(consumer_names)} consumers: {consumer_names}")
assert len(consumer_names) == 2, (
f"Expected 2 durable consumers (one per router), "
f"found {len(consumer_names)}: {consumer_names}"
nc = await nats.connect("nats://localhost:4222")
try:
js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name)
consumer_names = [info.name for info in consumer_infos]
logger.info(f"Found {len(consumer_names)} consumers: {consumer_names}")
assert len(consumer_names) == 2, (
f"Expected 2 durable consumers (one per router), "
f"found {len(consumer_names)}: {consumer_names}"
)
logger.info("✓ Verified 2 durable consumers exist (one per router)")
finally:
await nc.close()
else:
logger.info(
"Skipping NATS consumers verification (local indexer uses NATS Core, not JetStream)"
)
logger.info("✓ Verified 2 durable consumers exist (one per router)")
finally:
await nc.close()
# Run the async test
asyncio.run(test_sync())
......
......@@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional
import pytest
from tests.conftest import EtcdServer, NatsServer
from tests.router.common import ( # utilities
_test_busy_threshold_endpoint,
_test_python_router_bindings,
......@@ -476,53 +478,82 @@ def test_kv_push_router_bindings(
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("store_backend", ["etcd", "file"])
# NO @pytest.mark.parallel - nats_core variant stops/restarts NATS
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
[
("etcd", False, "nats"), # JetStream mode
# ("etcd", True, "tcp"), # ignored, needs unconditional nats_client
("file", False, "nats"), # File backend
],
ids=["jetstream", "file"], # "nats_core" commented out to match commented test case
)
def test_indexers_sync(
request,
runtime_services_session,
predownload_tokenizers,
file_storage_backend,
store_backend,
use_nats_core,
request_plane,
):
"""
Test that two KV routers have synchronized indexer states after processing requests.
This test verifies that both routers converge to the same internal state.
Tests with both etcd and file storage backends.
"""
# runtime_services starts etcd and nats
logger.info(f"Starting indexers sync test with {store_backend} storage backend")
# Create mocker args dictionary
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
try:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
_test_router_indexers_sync(
engine_workers=mockers,
block_size=BLOCK_SIZE,
model_name=MODEL_NAME,
num_workers=NUM_MOCKERS,
store_backend=store_backend,
)
logger.info("Indexers sync test completed successfully")
Tests with three configurations:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- nats_core: etcd backend, local indexer with NATS Core, TCP request plane
(includes NATS interruption/recovery testing)
- file: file backend, JetStream for KV events, NATS request plane
"""
logger.info(
f"Starting indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
)
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
# Start NATS manually (needed for all variants - KV event sync)
with NatsServer(request) as nats_server:
# Start etcd if needed
etcd_ctx = EtcdServer(request) if store_backend == "etcd" else nullcontext()
with etcd_ctx:
# Create mocker args dictionary
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
}
try:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
store_backend=store_backend,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
_test_router_indexers_sync(
engine_workers=mockers,
block_size=BLOCK_SIZE,
model_name=MODEL_NAME,
num_workers=NUM_MOCKERS,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=use_nats_core,
nats_server=nats_server if use_nats_core else None,
)
logger.info("Indexers sync test completed successfully")
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.parallel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment