Unverified Commit 7ff5e0be authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Concurrent KV event consumer (#7293)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
parent 0f8e1a9e
...@@ -18,6 +18,74 @@ use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager}; ...@@ -18,6 +18,74 @@ use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*; use crate::protocols::*;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
/// Builds the prune-tracking entries for a `Stored` event.
///
/// Returns `None` for any event whose payload is not `KvCacheEventData::Stored`.
/// Otherwise returns one `BlockEntry` per stored block, in block order, each
/// tagged with the originating worker/dp_rank and the block's index within the
/// event as its `seq_position`.
fn stored_block_entries(event: &RouterEvent) -> Option<Vec<BlockEntry>> {
    match &event.event.data {
        KvCacheEventData::Stored(store_data) => {
            let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
            let mut entries = Vec::with_capacity(store_data.blocks.len());
            for (seq_position, block) in store_data.blocks.iter().enumerate() {
                entries.push(BlockEntry {
                    key: block.block_hash,
                    worker,
                    seq_position,
                });
            }
            Some(entries)
        }
        _ => None,
    }
}
/// Applies one KV cache event to the radix tree, records the outcome in the
/// metrics, and (when a prune manager is configured) tracks the stored blocks.
///
/// # Arguments
///
/// - `trie`: the radix tree the event is applied to.
/// - `event`: the event to apply; it is consumed by `RadixTree::apply_event`.
/// - `metrics`: receives the event type together with the apply result.
/// - `prune_manager`: when `Some`, blocks from successfully applied `Stored`
///   events are inserted into it.
/// - `prune_tx`: signalled (best effort) when the tree exceeds
///   `prune_config.max_tree_size`.
///
/// Nothing is tracked for pruning when the apply fails, when no prune manager
/// is present, or when the event is not a `Stored` event.
fn apply_event_with_prune_tracking(
    trie: &mut RadixTree,
    event: RouterEvent,
    metrics: &KvIndexerMetrics,
    prune_manager: &mut Option<PruneManager<BlockEntry>>,
    prune_tx: &mpsc::Sender<()>,
) {
    let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
    let event_id = event.event.event_id;
    let worker_id = event.worker_id;

    // Extract the prune entries by reference before the event is moved into the
    // tree. This replaces a deep clone of the whole event: only `Stored` events
    // produce entries, and only the block hashes are copied.
    let pending_entries = if prune_manager.is_some() {
        stored_block_entries(&event)
    } else {
        None
    };

    let result = trie.apply_event(event);
    let result_is_ok = result.is_ok();
    let tree_size = trie.current_size();
    tracing::trace!(
        "Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
    );
    metrics.increment_event_applied(event_type, result);

    let Some(pm) = prune_manager.as_mut() else {
        return;
    };
    // Blocks from an event the tree rejected must not be tracked.
    if !result_is_ok {
        return;
    }
    let Some(block_entries) = pending_entries else {
        return;
    };
    pm.insert(block_entries);

    let Some(ref pc) = pm.prune_config else {
        return;
    };
    let current_size = trie.current_size();
    if current_size > pc.max_tree_size {
        tracing::info!(
            "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
            current_size,
            pc.max_tree_size
        );
        // Best effort: the send result is deliberately ignored.
        let _ = prune_tx.try_send(());
    }
}
/// The KV Indexer, managing the KV store and handling events and match requests. /// The KV Indexer, managing the KV store and handling events and match requests.
#[derive(Clone)] #[derive(Clone)]
pub struct KvIndexer { pub struct KvIndexer {
...@@ -64,7 +132,7 @@ impl KvIndexer { ...@@ -64,7 +132,7 @@ impl KvIndexer {
metrics: Arc<KvIndexerMetrics>, metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>, prune_config: Option<PruneConfig>,
) -> Self { ) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) = let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
...@@ -151,49 +219,26 @@ impl KvIndexer { ...@@ -151,49 +219,26 @@ impl KvIndexer {
} }
Some(event) = event_rx.recv() => { Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); apply_event_with_prune_tracking(
let event_id = event.event.event_id; &mut trie,
let worker_id = event.worker_id; event,
// Only clone if we need the event for prune_manager afterward &metrics,
let event_for_prune = prune_manager.is_some().then(|| event.clone()); &mut prune_manager,
let result = trie.apply_event(event); &prune_tx,
let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
); );
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
let Some(ref mut pm) = prune_manager else { continue };
if !result_is_ok { continue };
let Some(ref event) = event_for_prune else { continue };
let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
let block_entries: Vec<BlockEntry> = store_data.blocks.iter().enumerate().map(|(idx, block)| {
BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
} }
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size Some(dump_req) = dump_rx.recv() => {
let Some(ref pc) = pm.prune_config else { continue }; // Flush pending events so tree is consistent with buffer
let current_size = trie.current_size(); while let Ok(event) = event_rx.try_recv() {
if current_size > pc.max_tree_size { apply_event_with_prune_tracking(
tracing::info!( &mut trie,
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning", event,
current_size, &metrics,
pc.max_tree_size &mut prune_manager,
&prune_tx,
); );
let _ = prune_tx.try_send(());
}
} }
Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events(); let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events); let _ = dump_req.resp.send(events);
} }
......
...@@ -53,6 +53,15 @@ impl LocalKvIndexer { ...@@ -53,6 +53,15 @@ impl LocalKvIndexer {
buffer.iter().cloned().collect() buffer.iter().cloned().collect()
} }
/// Assemble a `TreeDump` response that pairs the current tree dump with the
/// caller-supplied `last_event_id`.
///
/// If `dump_events` fails, the response carries the default (empty) event list.
async fn tree_dump_response(&self, last_event_id: u64) -> WorkerKvQueryResponse {
    let dumped = self.dump_events().await;
    WorkerKvQueryResponse::TreeDump {
        events: dumped.unwrap_or_default(),
        last_event_id,
    }
}
/// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive). /// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive).
/// ///
/// ### Arguments /// ### Arguments
...@@ -63,7 +72,7 @@ impl LocalKvIndexer { ...@@ -63,7 +72,7 @@ impl LocalKvIndexer {
/// ### Returns /// ### Returns
/// ///
/// - `Events`: Buffered events with original IDs (when range is within buffer) /// - `Events`: Buffered events with original IDs (when range is within buffer)
/// - `TreeDump`: Full tree dump with synthetic IDs (when range is too old or unspecified) /// - `TreeDump`: Full tree dump with synthetic IDs and the worker's latest real event ID (when range is too old or unspecified)
/// - `TooNew`: Error when requested range is newer than available data /// - `TooNew`: Error when requested range is newer than available data
/// - `InvalidRange`: Error when end_id < start_id /// - `InvalidRange`: Error when end_id < start_id
pub async fn get_events_in_id_range( pub async fn get_events_in_id_range(
...@@ -98,8 +107,7 @@ impl LocalKvIndexer { ...@@ -98,8 +107,7 @@ impl LocalKvIndexer {
// If no start_id specified, dump entire tree // If no start_id specified, dump entire tree
if start_id.is_none() { if start_id.is_none() {
tracing::debug!("No start_id specified, dumping entire tree"); tracing::debug!("No start_id specified, dumping entire tree");
let events = self.dump_events().await.unwrap_or_default(); return self.tree_dump_response(last_id.unwrap_or(0)).await;
return WorkerKvQueryResponse::TreeDump(events);
} }
let start_id = start_id.unwrap(); let start_id = start_id.unwrap();
...@@ -108,8 +116,7 @@ impl LocalKvIndexer { ...@@ -108,8 +116,7 @@ impl LocalKvIndexer {
// Check for empty buffer // Check for empty buffer
let Some(first_buffered) = first_id else { let Some(first_buffered) = first_id else {
tracing::debug!("Buffer empty, dumping entire tree"); tracing::debug!("Buffer empty, dumping entire tree");
let events = self.dump_events().await.unwrap_or_default(); return self.tree_dump_response(0).await;
return WorkerKvQueryResponse::TreeDump(events);
}; };
let last_buffered = last_id.unwrap(); let last_buffered = last_id.unwrap();
...@@ -134,8 +141,7 @@ impl LocalKvIndexer { ...@@ -134,8 +141,7 @@ impl LocalKvIndexer {
first_buffered, first_buffered,
"Requested start_id is older than buffer, dumping entire tree" "Requested start_id is older than buffer, dumping entire tree"
); );
let events = self.dump_events().await.unwrap_or_default(); return self.tree_dump_response(last_buffered).await;
return WorkerKvQueryResponse::TreeDump(events);
} }
// Serve from buffer // Serve from buffer
...@@ -196,17 +202,20 @@ impl LocalKvIndexer { ...@@ -196,17 +202,20 @@ impl LocalKvIndexer {
/// Apply event with buffering. /// Apply event with buffering.
/// ///
/// This records the event in the buffer and forwards it to the underlying indexer. /// This forwards the event to the underlying indexer and records it on success.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> { pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
// Record in buffer
self.record_event(event.clone());
// Forward to underlying indexer // Forward to underlying indexer
self.indexer let result = self
.indexer
.event_sender() .event_sender()
.send(event) .send(event.clone())
.await .await
.map_err(|_| KvRouterError::IndexerOffline) .map_err(|_| KvRouterError::IndexerOffline);
if result.is_ok() {
self.record_event(event);
}
result
} }
/// Clear the event buffer. /// Clear the event buffer.
......
...@@ -1941,7 +1941,7 @@ async fn test_local_indexer_slice_within_range() { ...@@ -1941,7 +1941,7 @@ async fn test_local_indexer_slice_within_range() {
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> { let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp { match resp {
WorkerKvQueryResponse::Events(e) => e, WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e, WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type"), _ => panic!("Unexpected response type"),
} }
}; };
...@@ -1962,7 +1962,7 @@ async fn test_local_indexer_slice_within_range() { ...@@ -1962,7 +1962,7 @@ async fn test_local_indexer_slice_within_range() {
// start_id=0 is before buffer (first is 1), so should trigger tree dump // start_id=0 is before buffer (first is 1), so should trigger tree dump
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await; let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_))); assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await; let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
let ids = get_ids(extract_events(result)); let ids = get_ids(extract_events(result));
...@@ -2016,7 +2016,7 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() { ...@@ -2016,7 +2016,7 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> { let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp { match resp {
WorkerKvQueryResponse::Events(e) => e, WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e, WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type: {:?}", resp), _ => panic!("Unexpected response type: {:?}", resp),
} }
}; };
...@@ -2038,11 +2038,11 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() { ...@@ -2038,11 +2038,11 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
// Tree dump path tests // Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await; let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_))); assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
assert_eq!(extract_events(result).len(), 10); assert_eq!(extract_events(result).len(), 10);
let result = indexer.get_events_in_id_range(Some(7), None).await; let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_))); assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Edge cases // Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await; let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
...@@ -2052,6 +2052,98 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() { ...@@ -2052,6 +2052,98 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. })); assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
} }
// Checks the `last_event_id` carried by `TreeDump` responses in three situations:
// no start_id given, a start_id older than the buffer, and an indexer that has
// never received an event.
#[tokio::test]
async fn test_tree_dump_includes_last_event_id() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
// Builds a single-block `Stored` event for worker 0 / dp_rank 0 whose hashes are
// derived from `id`, so every event stores a distinct block.
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
// Flush before querying so the dumps below are taken after the sends above.
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Empty buffer case: create a fresh indexer with no events
let empty_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 0,
"empty buffer should return last_event_id=0"
);
assert!(events.is_empty(), "empty indexer should have no events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test] #[tokio::test]
async fn test_local_indexer_buffer_and_serialization() { async fn test_local_indexer_buffer_and_serialization() {
let worker_id = 42u64; let worker_id = 42u64;
...@@ -2099,6 +2191,51 @@ async fn test_local_indexer_buffer_and_serialization() { ...@@ -2099,6 +2191,51 @@ async fn test_local_indexer_buffer_and_serialization() {
assert_eq!(events[0].worker_id, worker_id); assert_eq!(events[0].worker_id, worker_id);
} }
// Checks that an event whose send to the underlying indexer fails is not recorded
// in the buffer, and that a later tree dump is empty with `last_event_id` 0.
#[tokio::test]
async fn test_local_indexer_does_not_buffer_failed_send() {
let local_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let test_event = RouterEvent::new(
7,
KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
);
// Shut the indexer down and wait until the event channel reports closed, so the
// send inside `apply_event_with_buffer` below is expected to fail.
let event_tx = local_indexer.event_sender();
local_indexer.shutdown();
event_tx.closed().await;
let result = local_indexer.apply_event_with_buffer(test_event).await;
assert!(matches!(result, Err(KvRouterError::IndexerOffline)));
// The failed event must not have been buffered.
assert_eq!(local_indexer.buffer_len(), 0);
match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
} => {
assert!(events.is_empty());
assert_eq!(last_event_id, 0);
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_apply_events_idempotent(variant: &str) { async fn test_apply_events_idempotent(variant: &str) {
......
...@@ -56,8 +56,13 @@ pub struct WorkerKvQueryRequest { ...@@ -56,8 +56,13 @@ pub struct WorkerKvQueryRequest {
pub enum WorkerKvQueryResponse { pub enum WorkerKvQueryResponse {
/// Events served from the circular buffer (with original event IDs) /// Events served from the circular buffer (with original event IDs)
Events(Vec<RouterEvent>), Events(Vec<RouterEvent>),
/// Full tree dump (with synthetic 0-indexed event IDs) /// Full tree dump (with synthetic 0-indexed event IDs).
TreeDump(Vec<RouterEvent>), /// Includes `last_event_id`: the newest real event ID in the worker's buffer
/// at the time of the dump, so the caller can set its tracking cursor correctly.
TreeDump {
events: Vec<RouterEvent>,
last_event_id: u64,
},
/// Requested range is newer than available data /// Requested range is newer than available data
TooNew { TooNew {
requested_start: Option<u64>, requested_start: Option<u64>,
......
...@@ -2024,7 +2024,7 @@ mod tests_startup_helpers { ...@@ -2024,7 +2024,7 @@ mod tests_startup_helpers {
.await; .await;
let missed_events = match response { let missed_events = match response {
crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e, crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e,
crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump(e) => e, crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
crate::kv_router::indexer::WorkerKvQueryResponse::Error(message) => { crate::kv_router::indexer::WorkerKvQueryResponse::Error(message) => {
panic!("Unexpected error response: {message}") panic!("Unexpected error response: {message}")
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use crate::kv_router::{ use crate::kv_router::{
Indexer, KV_EVENT_SUBJECT, KvRouterConfig, Indexer, KV_EVENT_SUBJECT, KvRouterConfig, protocols::RouterEvent,
protocols::{DpRank, RouterEvent, WorkerId},
worker_query::WorkerQueryClient, worker_query::WorkerQueryClient,
}; };
use anyhow::Result; use anyhow::Result;
...@@ -23,9 +20,6 @@ use dynamo_runtime::{ ...@@ -23,9 +20,6 @@ use dynamo_runtime::{
/// - On worker Added: dumps worker's local indexer into router /// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer /// - On worker Removed: removes worker from router indexer
/// ///
/// This function first recovers state from all currently registered workers before
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled. /// This is appropriate when workers have local indexers enabled.
async fn start_kv_router_background_event_plane( async fn start_kv_router_background_event_plane(
component: Component, component: Component,
...@@ -47,7 +41,9 @@ async fn start_kv_router_background_event_plane( ...@@ -47,7 +41,9 @@ async fn start_kv_router_background_event_plane(
// before recovery fetches the initial dump from workers. // before recovery fetches the initial dump from workers.
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer.clone()).await?; // WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer).await?;
let kv_event_subject = format!( let kv_event_subject = format!(
"namespace.{}.component.{}.{}", "namespace.{}.component.{}.{}",
component.namespace().name(), component.namespace().name(),
...@@ -71,10 +67,6 @@ async fn start_kv_router_background_event_plane( ...@@ -71,10 +67,6 @@ async fn start_kv_router_background_event_plane(
} }
tokio::spawn(async move { tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection
// Each dp_rank has its own monotonic event ID sequence
let mut last_event_ids: HashMap<(WorkerId, DpRank), u64> = HashMap::new();
loop { loop {
tokio::select! { tokio::select! {
biased; biased;
...@@ -94,47 +86,19 @@ async fn start_kv_router_background_event_plane( ...@@ -94,47 +86,19 @@ async fn start_kv_router_background_event_plane(
} }
}; };
let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank);
tracing::trace!( tracing::trace!(
"Received event from publisher {} (seq {})", "Received event from publisher {} (seq {})",
envelope.publisher_id, envelope.publisher_id,
envelope.sequence envelope.sequence
); );
// Gap detection: check if event ID is monotonically increasing per (worker, dp_rank) tracing::trace!(
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent) "Forwarding live event to recovery coordinator for worker {} dp_rank {} event_id {}",
if let Some(&last_id) = last_event_ids.get(&event_key) event.worker_id,
&& event_id > last_id + 1 event.event.dp_rank,
{ event.event.event_id
let gap_start = last_id + 1;
let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1;
tracing::warn!(
"Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
if let Err(e) = worker_query_client
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end))
.await
{
tracing::error!(
"Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
); );
} worker_query_client.handle_live_event(event).await;
}
// Update last seen event ID (use max to handle out-of-order)
last_event_ids
.entry(event_key)
.and_modify(|id| *id = (*id).max(event_id))
.or_insert(event_id);
// Forward the RouterEvent to the indexer
indexer.apply_event(event).await;
} }
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery}; use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode, AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode,
SingleIn, async_trait, network::Ingress, SingleIn, network::Ingress,
}; };
use dynamo_runtime::protocols::maybe_error::MaybeError; use dynamo_runtime::protocols::maybe_error::MaybeError;
use dynamo_runtime::stream; use dynamo_runtime::stream;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::{Mutex, Semaphore};
use crate::kv_router::Indexer; use crate::kv_router::Indexer;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse}; use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::{DpRank, WorkerId}; use crate::kv_router::protocols::{DpRank, KvCacheEventData, RouterEvent, WorkerId};
use crate::kv_router::worker_kv_indexer_query_endpoint; use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration // Recovery retry configuration
const RECOVERY_MAX_RETRIES: u32 = 8; const RECOVERY_MAX_RETRIES: u32 = 8;
const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200; const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
const RECOVERY_CONCURRENCY_LIMIT: usize = 16;
/// Prefix for worker KV indexer query endpoint names. /// Prefix for worker KV indexer query endpoint names.
const QUERY_ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp"; const QUERY_ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp";
/// Pairs a worker ID with a dp_rank.
// NOTE(review): by its name this keys per-(worker, dp_rank) recovery bookkeeping;
// its uses are outside this view — confirm.
type RecoveryKey = (WorkerId, DpRank);
/// Position of one dp_rank's event cursor.
///
/// - `NeedsRestore` (the default): no applied event ID is known.
/// - `Live(id)`: `id` is the last applied event ID.
/// - `InvalidatedByBarrier(prev)`: carries the last applied event ID, if any.
#[derive(Clone, Copy, Debug, Default)]
enum RankCursor {
    #[default]
    NeedsRestore,
    Live(u64),
    InvalidatedByBarrier(Option<u64>),
}

/// Per-dp_rank tracking state: the cursor, the highest live event ID observed so
/// far (if any), and a flag recording whether a recovery is in flight.
#[derive(Debug, Default)]
struct RankState {
    cursor: RankCursor,
    max_seen_live_id: Option<u64>,
    recovery_inflight: bool,
}

impl RankState {
    /// Returns the last applied event ID implied by the cursor, or `None` when
    /// the cursor carries no ID.
    fn last_applied_id(&self) -> Option<u64> {
        match self.cursor {
            RankCursor::Live(applied) => Some(applied),
            RankCursor::InvalidatedByBarrier(previous) => previous,
            RankCursor::NeedsRestore => None,
        }
    }

    /// Raises `max_seen_live_id` to `event_id` if it is larger than the current
    /// maximum (or if nothing has been observed yet).
    fn observe_live_id(&mut self, event_id: u64) {
        let highest = match self.max_seen_live_id {
            Some(seen) if seen > event_id => seen,
            _ => event_id,
        };
        self.max_seen_live_id = Some(highest);
    }

    /// Resets `max_seen_live_id` to `None` once `last_applied_id` has reached or
    /// passed it; otherwise leaves it untouched.
    fn clear_max_seen_if_caught_up(&mut self, last_applied_id: u64) {
        let caught_up = matches!(self.max_seen_live_id, Some(seen) if seen <= last_applied_id);
        if caught_up {
            self.max_seen_live_id = None;
        }
    }
}
/// Per-worker state: an `epoch` counter and one `RankState` per dp_rank.
// NOTE(review): `epoch` presumably distinguishes successive incarnations of the
// same worker so stale recovery results can be discarded — confirm against its uses.
#[derive(Debug, Default)]
struct WorkerState {
epoch: u64,
ranks: HashMap<DpRank, RankState>,
}
/// Abstraction over how a KV query is delivered to one worker's dp_rank.
///
/// `WorkerQueryClient` holds this as `Arc<dyn WorkerQueryTransport>`, so the
/// implementation can be swapped.
// NOTE(review): presumably introduced so tests can inject a fake transport — confirm.
#[async_trait]
trait WorkerQueryTransport: Send + Sync {
/// Queries `worker_id` at `dp_rank` for events bounded by the optional
/// `start_event_id` / `end_event_id` and returns the worker's response.
async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse>;
}
/// `WorkerQueryTransport` implementation that sends queries through the runtime
/// `Component`'s per-dp_rank query endpoints.
struct RuntimeWorkerQueryTransport {
/// Component whose endpoints are used to build the routers.
component: Component,
/// Routers keyed by dp_rank; filled in on first use by `get_router_for_dp_rank`.
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
}
impl RuntimeWorkerQueryTransport {
    /// Wraps `component` with an initially empty per-dp_rank router cache.
    fn new(component: Component) -> Self {
        let routers = DashMap::new();
        Self { component, routers }
    }

    /// Returns the router for `dp_rank`, building and caching it on first use.
    ///
    /// On a cache miss this resolves the dp_rank's query endpoint, creates a
    /// client for it and wraps the client in a round-robin `PushRouter` without
    /// fault detection. If another caller cached a router for the same dp_rank
    /// in the meantime, the already-cached router is returned.
    async fn get_router_for_dp_rank(
        &self,
        dp_rank: DpRank,
    ) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> {
        // Fast path: clone the cached handle; the map guard is released before any await.
        let cached = self.routers.get(&dp_rank).map(|hit| hit.clone());
        if let Some(router) = cached {
            return Ok(router);
        }

        let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
        let client = self.component.endpoint(&endpoint_name).client().await?;
        let built = PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?;
        let fresh = Arc::new(built);

        // `or_insert` keeps whichever router reached the map first.
        let shared = self.routers.entry(dp_rank).or_insert(fresh).clone();
        Ok(shared)
    }
}
#[async_trait]
impl WorkerQueryTransport for RuntimeWorkerQueryTransport {
    /// Sends a `WorkerKvQueryRequest` directly to `worker_id` over the router
    /// for `dp_rank` and returns the first response on the stream.
    ///
    /// # Errors
    ///
    /// Fails when the router cannot be obtained, the request cannot be sent,
    /// the response stream is empty, or the response itself reports an error.
    async fn query_worker(
        &self,
        worker_id: WorkerId,
        dp_rank: DpRank,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let router = self.get_router_for_dp_rank(dp_rank).await?;
        let query = WorkerKvQueryRequest {
            worker_id,
            start_event_id,
            end_event_id,
        };

        let dispatched = router.direct(SingleIn::new(query), worker_id).await;
        let mut responses = dispatched.with_context(|| {
            format!("Failed to send worker KV query to worker {worker_id} dp_rank {dp_rank}")
        })?;

        // Only the first item of the stream is consumed.
        let first = responses.next().await;
        let response = first.context("Worker KV query returned an empty response stream")?;

        if let Some(err) = response.err() {
            return Err(err).context("Worker KV query response error");
        }
        Ok(response)
    }
}
/// Router-side client for querying worker local KV indexers. /// Router-side client for querying worker local KV indexers.
/// ///
/// Discovers query endpoints via `ComponentEndpoints` discovery, filtering for /// Discovers query endpoints via `ComponentEndpoints` discovery, filtering for
/// the `worker_kv_indexer_query_dp{N}` name pattern. Recovers each /// the `worker_kv_indexer_query_dp{N}` name pattern. Coordinates restore and
/// `(worker_id, dp_rank)` individually as it appears in discovery. /// gap recovery at the worker level while still querying each `(worker_id,
/// dp_rank)` endpoint independently.
/// ///
/// Also handles worker lifecycle (add/remove) by tracking known endpoints and /// Also handles worker lifecycle (add/remove) by tracking known endpoints and
/// sending removal events to the router indexer when all dp_ranks for a worker /// sending removal events to the router indexer when all dp_ranks for a worker
/// disappear. /// disappear.
pub struct WorkerQueryClient { pub struct WorkerQueryClient {
component: Component, component: Component,
/// Routers keyed by dp_rank — each dp_rank has its own endpoint. Created lazily. transport: Arc<dyn WorkerQueryTransport>,
routers: Arc<DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>>,
/// Indexer for applying recovered events and worker removals. /// Indexer for applying recovered events and worker removals.
indexer: Indexer, indexer: Indexer,
worker_states: DashMap<WorkerId, Arc<Mutex<WorkerState>>>,
recovery_semaphore: Arc<Semaphore>,
} }
impl WorkerQueryClient { impl WorkerQueryClient {
/// Builds a shared client around the given component, indexer and transport.
///
/// The client starts with no per-worker state and a recovery semaphore sized
/// by `RECOVERY_CONCURRENCY_LIMIT`.
fn new(
    component: Component,
    indexer: Indexer,
    transport: Arc<dyn WorkerQueryTransport>,
) -> Arc<Self> {
    let recovery_semaphore = Arc::new(Semaphore::new(RECOVERY_CONCURRENCY_LIMIT));
    let client = Self {
        component,
        transport,
        indexer,
        worker_states: DashMap::new(),
        recovery_semaphore,
    };
    Arc::new(client)
}
/// Create a new WorkerQueryClient and spawn its background discovery loop. /// Create a new WorkerQueryClient and spawn its background discovery loop.
/// ///
/// The background loop watches `ComponentEndpoints` discovery for query endpoints, /// The background loop watches `ComponentEndpoints` discovery for query endpoints,
/// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal /// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal
/// events when all dp_ranks for a worker disappear. /// events when all dp_ranks for a worker disappear.
pub async fn spawn(component: Component, indexer: Indexer) -> Result<Arc<Self>> { pub async fn spawn(component: Component, indexer: Indexer) -> Result<Arc<Self>> {
let client = Arc::new(Self { let transport = Arc::new(RuntimeWorkerQueryTransport::new(component.clone()));
component: component.clone(), let client = Self::new(component.clone(), indexer, transport);
routers: Arc::new(DashMap::new()),
indexer,
});
let client_bg = client.clone(); let client_bg = client.clone();
let cancel_token = component.drt().primary_token(); let cancel_token = component.drt().primary_token();
...@@ -71,9 +212,9 @@ impl WorkerQueryClient { ...@@ -71,9 +212,9 @@ impl WorkerQueryClient {
Ok(client) Ok(client)
} }
/// Background loop: watches ComponentEndpoints, recovers per (worker_id, dp_rank). /// Background loop: watches ComponentEndpoints and schedules worker-coordinated recovery.
async fn run_discovery_loop( async fn run_discovery_loop(
&self, self: Arc<Self>,
cancel_token: tokio_util::sync::CancellationToken, cancel_token: tokio_util::sync::CancellationToken,
) -> Result<()> { ) -> Result<()> {
let discovery = self.component.drt().discovery(); let discovery = self.component.drt().discovery();
...@@ -87,9 +228,6 @@ impl WorkerQueryClient { ...@@ -87,9 +228,6 @@ impl WorkerQueryClient {
) )
.await?; .await?;
// Track known (worker_id, dp_rank) pairs to detect removals
let mut known: HashMap<WorkerId, HashSet<DpRank>> = HashMap::new();
while let Some(result) = stream.next().await { while let Some(result) = stream.next().await {
if cancel_token.is_cancelled() { if cancel_token.is_cancelled() {
break; break;
...@@ -108,45 +246,13 @@ impl WorkerQueryClient { ...@@ -108,45 +246,13 @@ impl WorkerQueryClient {
let Some((worker_id, dp_rank)) = Self::parse_query_endpoint(&instance) else { let Some((worker_id, dp_rank)) = Self::parse_query_endpoint(&instance) else {
continue; continue;
}; };
self.handle_discovered_worker(worker_id, dp_rank).await;
if known.entry(worker_id).or_default().insert(dp_rank) {
tracing::info!(
"WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, recovering"
);
match self
.recover_from_worker(worker_id, dp_rank, None, None)
.await
{
Ok(count) => {
if count > 0 {
tracing::info!(
"Recovered {count} events from worker {worker_id} dp_rank {dp_rank}"
);
}
}
Err(e) => {
tracing::warn!(
"Failed to recover from worker {worker_id} dp_rank {dp_rank}: {e}"
);
}
}
}
} }
DiscoveryEvent::Removed(id) => { DiscoveryEvent::Removed(id) => {
let Some((worker_id, dp_rank)) = Self::parse_instance_id(&id) else { let Some((worker_id, dp_rank)) = Self::parse_instance_id(&id) else {
continue; continue;
}; };
self.handle_removed_worker_dp(worker_id, dp_rank).await;
if let Some(dp_ranks) = known.get_mut(&worker_id) {
dp_ranks.remove(&dp_rank);
if dp_ranks.is_empty() {
known.remove(&worker_id);
tracing::warn!(
"WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing"
);
self.indexer.remove_worker(worker_id).await;
}
}
} }
} }
} }
...@@ -177,76 +283,362 @@ impl WorkerQueryClient { ...@@ -177,76 +283,362 @@ impl WorkerQueryClient {
Some((eid.instance_id, dp_rank)) Some((eid.instance_id, dp_rank))
} }
/// Get or create a router for the specified dp_rank's endpoint. fn get_or_create_worker_state(&self, worker_id: WorkerId) -> Arc<Mutex<WorkerState>> {
async fn get_router_for_dp_rank( self.worker_states
&self, .entry(worker_id)
.or_insert_with(|| Arc::new(Mutex::new(WorkerState::default())))
.clone()
}
pub(crate) async fn handle_discovered_worker(
self: &Arc<Self>,
worker_id: WorkerId,
dp_rank: DpRank, dp_rank: DpRank,
) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> { ) {
if let Some(router) = self.routers.get(&dp_rank) { let worker_state = self.get_or_create_worker_state(worker_id);
return Ok(router.clone()); let spawn = {
let mut worker_state = worker_state.lock().await;
let rank_state = worker_state.ranks.entry(dp_rank).or_default();
if matches!(rank_state.cursor, RankCursor::NeedsRestore)
&& !rank_state.recovery_inflight
{
tracing::info!(
"WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, scheduling restore"
);
rank_state.recovery_inflight = true;
Some(worker_state.epoch)
} else {
None
} }
};
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank); if let Some(epoch) = spawn {
let endpoint = self.component.endpoint(&endpoint_name); self.spawn_recovery_task((worker_id, dp_rank), epoch, None, None);
let client = endpoint.client().await?; }
let router = Arc::new( }
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?,
pub(crate) async fn handle_removed_worker_dp(&self, worker_id: WorkerId, dp_rank: DpRank) {
let Some(worker_state) = self
.worker_states
.get(&worker_id)
.map(|entry| entry.clone())
else {
return;
};
let should_remove_worker = {
let mut worker_state = worker_state.lock().await;
if worker_state.ranks.remove(&dp_rank).is_none() {
return;
}
worker_state.ranks.is_empty()
};
if should_remove_worker {
tracing::warn!("WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing");
self.worker_states.remove(&worker_id);
self.indexer.remove_worker(worker_id).await;
}
}
async fn apply_worker_clear_locked(&self, worker_state: &mut WorkerState, event: RouterEvent) {
let worker_id = event.worker_id;
let clear_dp_rank = event.event.dp_rank;
let clear_event_id = event.event.event_id;
worker_state.epoch += 1;
for rank_state in worker_state.ranks.values_mut() {
rank_state.cursor = RankCursor::InvalidatedByBarrier(rank_state.last_applied_id());
rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false;
}
let rank_state = worker_state.ranks.entry(clear_dp_rank).or_default();
rank_state.cursor = RankCursor::Live(clear_event_id);
tracing::info!(
"Applying clear barrier for worker {worker_id}; invalidating recovery across {} dp_ranks",
worker_state.ranks.len()
); );
self.indexer.apply_event(event).await;
}
pub(crate) async fn handle_live_event(self: &Arc<Self>, event: RouterEvent) {
let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id;
let key = (worker_id, dp_rank);
Ok(self enum Action {
.routers ApplyDirect,
.entry(dp_rank) SpawnFullRestore { epoch: u64 },
.or_insert(router) SpawnIncremental { epoch: u64, start_event_id: u64 },
.value()
.clone())
} }
/// Query a specific worker's local KV indexer for a specific dp_rank. let action = {
pub async fn query_worker( let worker_state = self.get_or_create_worker_state(worker_id);
&self, let mut worker_state = worker_state.lock().await;
worker_id: WorkerId, let rank_state = worker_state.ranks.entry(dp_rank).or_default();
dp_rank: DpRank,
// `Cleared` is worker-wide in the indexer, so it bypasses the per-rank gap logic and
// instead installs a worker barrier that invalidates any inflight recovery.
if matches!(&event.event.data, KvCacheEventData::Cleared) {
if rank_state
.last_applied_id()
.is_none_or(|last_applied_id| event_id > last_applied_id)
{
self.apply_worker_clear_locked(&mut worker_state, event)
.await;
}
// Already applied the event, so no further action needed.
return;
} else {
match rank_state.cursor {
// We have never established a cursor for this rank, so live traffic only tells
// us how far ahead the stream has moved while a full restore catches up.
RankCursor::NeedsRestore => {
rank_state.observe_live_id(event_id);
if !rank_state.recovery_inflight {
rank_state.recovery_inflight = true;
Action::SpawnFullRestore {
epoch: worker_state.epoch,
}
} else {
// A recovery is already in flight. Nothing to do.
return;
}
}
// Normal steady-state path: apply contiguous events directly, but coalesce any
// gap into a single recovery pass using `max_seen_live_id` as the high-water mark.
RankCursor::Live(last_applied_id) => {
if event_id <= last_applied_id {
// We've already applied this event. Nothing to do.
return;
} else if rank_state.recovery_inflight {
// A recovery is already in flight. Drop the event for now, and potentially spawn a new recovery afterwards.
rank_state.observe_live_id(event_id);
return;
} else if event_id > last_applied_id.saturating_add(1) {
// We've detected a gap. Spawn a new recovery pass.
rank_state.observe_live_id(event_id);
rank_state.recovery_inflight = true;
Action::SpawnIncremental {
epoch: worker_state.epoch,
start_event_id: last_applied_id.saturating_add(1),
}
} else {
// Apply the event.
rank_state.cursor = RankCursor::Live(event_id);
rank_state.clear_max_seen_if_caught_up(event_id);
Action::ApplyDirect
}
}
// A worker-wide barrier (currently `Cleared`) invalidated this rank's old
// cursor. The next newer live event becomes the new starting point; we do not
// recover across the barrier.
RankCursor::InvalidatedByBarrier(last_applied_id) => {
if last_applied_id
.is_some_and(|last_applied_id| event_id <= last_applied_id)
{
return;
} else {
rank_state.cursor = RankCursor::Live(event_id);
rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false;
Action::ApplyDirect
}
}
}
}
};
match action {
Action::ApplyDirect => {
self.indexer.apply_event(event).await;
}
Action::SpawnFullRestore { epoch } => {
self.spawn_recovery_task(key, epoch, None, None);
}
Action::SpawnIncremental {
epoch,
start_event_id,
} => {
self.spawn_recovery_task(key, epoch, Some(start_event_id), None);
}
}
}
fn spawn_recovery_task(
self: &Arc<Self>,
key: RecoveryKey,
epoch: u64,
start_event_id: Option<u64>, start_event_id: Option<u64>,
end_event_id: Option<u64>, end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> { ) {
let router = self.get_router_for_dp_rank(dp_rank).await?; let client = self.clone();
let request = WorkerKvQueryRequest { tokio::spawn(async move {
worker_id, let Ok(_permit) = client.recovery_semaphore.clone().acquire_owned().await else {
start_event_id, return;
end_event_id,
}; };
let mut stream = router
.direct(SingleIn::new(request), worker_id)
.await
.with_context(|| {
format!("Failed to send worker KV query to worker {worker_id} dp_rank {dp_rank}")
})?;
let response = stream let result = client
.next() .fetch_recovery_response(key.0, key.1, start_event_id, end_event_id)
.await .await;
.context("Worker KV query returned an empty response stream")?; client.finish_recovery_task(key, epoch, result).await;
});
}
if let Some(err) = response.err() { async fn finish_recovery_task(
return Err(err).context("Worker KV query response error"); self: Arc<Self>,
key: RecoveryKey,
epoch: u64,
result: Result<WorkerKvQueryResponse>,
) {
let Some(worker_state) = self.worker_states.get(&key.0).map(|entry| entry.clone()) else {
return;
};
let mut worker_state = worker_state.lock().await;
if worker_state.epoch != epoch {
tracing::debug!(
"Discarding stale recovery result for worker {} dp_rank {} due to epoch change",
key.0,
key.1
);
return;
} }
Ok(response) let Some(rank_state) = worker_state.ranks.get(&key.1) else {
return;
};
let mut new_cursor = rank_state.cursor;
let mut successful_response = false;
let mut saw_clear = false;
match result {
Ok(WorkerKvQueryResponse::Events(events)) => {
tracing::debug!(
"Got {count} buffered events from worker {} dp_rank {}",
key.0,
key.1,
count = events.len()
);
for event in events {
let event_id = event.event.event_id;
if matches!(&event.event.data, KvCacheEventData::Cleared) {
self.apply_worker_clear_locked(&mut worker_state, event)
.await;
new_cursor = RankCursor::Live(event_id);
saw_clear = true;
continue;
}
self.indexer.apply_event(event).await;
new_cursor = RankCursor::Live(event_id);
}
successful_response = true;
}
Ok(WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
}) => {
tracing::info!(
"Got tree dump from worker {} dp_rank {} (range too old or unspecified), count: {}, last_event_id: {}",
key.0,
key.1,
events.len(),
last_event_id
);
for event in &events {
self.indexer.apply_event(event.clone()).await;
}
new_cursor = RankCursor::Live(last_event_id);
successful_response = true;
}
Ok(WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
}) => {
tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than available (newest: {newest_available}) for worker {} dp_rank {}",
key.0,
key.1
);
}
Ok(WorkerKvQueryResponse::InvalidRange { start_id, end_id }) => {
tracing::error!(
"Invalid range for worker {} dp_rank {}: end_id ({end_id}) < start_id ({start_id})",
key.0,
key.1
);
}
Ok(WorkerKvQueryResponse::Error(message)) => {
tracing::error!(
"Worker {} dp_rank {} query error: {}",
key.0,
key.1,
message
);
}
Err(error) => {
tracing::warn!(
"Failed recovery from worker {} dp_rank {}: {}",
key.0,
key.1,
error
);
}
}
let mut follow_up_start = None;
{
let rank_state = worker_state
.ranks
.get_mut(&key.1)
.expect("rank state should exist while finishing recovery");
rank_state.recovery_inflight = false;
rank_state.cursor = new_cursor;
let last_applied_id = rank_state.last_applied_id().unwrap_or(0);
rank_state.clear_max_seen_if_caught_up(last_applied_id);
if successful_response
&& !saw_clear
&& rank_state
.max_seen_live_id
.is_some_and(|max_seen| max_seen > last_applied_id)
{
rank_state.recovery_inflight = true;
follow_up_start = Some(last_applied_id.saturating_add(1));
}
}
let follow_up_epoch = worker_state.epoch;
drop(worker_state);
if let Some(start_event_id) = follow_up_start {
self.spawn_recovery_task(key, follow_up_epoch, Some(start_event_id), None);
}
} }
/// Query a worker's local KV indexer with exponential backoff retry. /// Query a worker's local KV indexer with exponential backoff retry.
async fn query_worker_with_retry( async fn fetch_recovery_response(
&self, &self,
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: DpRank, dp_rank: DpRank,
start_event_id: Option<u64>, start_event_id: Option<u64>,
end_event_id: Option<u64>, end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> { ) -> Result<WorkerKvQueryResponse> {
tracing::debug!(
"Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
let mut last_error = None; let mut last_error = None;
for attempt in 0..RECOVERY_MAX_RETRIES { for attempt in 0..RECOVERY_MAX_RETRIES {
match self match self
.transport
.query_worker(worker_id, dp_rank, start_event_id, end_event_id) .query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await .await
{ {
...@@ -275,79 +667,6 @@ impl WorkerQueryClient { ...@@ -275,79 +667,6 @@ impl WorkerQueryClient {
Err(last_error Err(last_error
.unwrap_or_else(|| anyhow::anyhow!("No response after {RECOVERY_MAX_RETRIES} retries"))) .unwrap_or_else(|| anyhow::anyhow!("No response after {RECOVERY_MAX_RETRIES} retries")))
} }
/// Fetches missed KV events from one worker dp_rank (with retries) and replays them
/// into the indexer.
///
/// Returns the number of events applied. `TooNew` responses and empty result sets yield
/// `Ok(0)`; `InvalidRange` and `Error` responses are returned as errors.
pub async fn recover_from_worker(
    &self,
    worker_id: WorkerId,
    dp_rank: DpRank,
    start_event_id: Option<u64>,
    end_event_id: Option<u64>,
) -> Result<usize> {
    tracing::debug!(
        "Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
         start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
    );

    // Reduce the response to the list of events to replay; every other variant exits here.
    let recovered = match self
        .query_worker_with_retry(worker_id, dp_rank, start_event_id, end_event_id)
        .await?
    {
        WorkerKvQueryResponse::Events(buffered) => {
            tracing::debug!(
                "Got {count} buffered events from worker {worker_id} dp_rank {dp_rank}",
                count = buffered.len()
            );
            buffered
        }
        WorkerKvQueryResponse::TreeDump(dump) => {
            tracing::info!(
                "Got tree dump from worker {worker_id} dp_rank {dp_rank} \
                 (range too old or unspecified), count: {count}",
                count = dump.len()
            );
            dump
        }
        WorkerKvQueryResponse::TooNew {
            requested_start,
            requested_end,
            newest_available,
        } => {
            tracing::warn!(
                "Requested range [{requested_start:?}, {requested_end:?}] is newer than \
                 available (newest: {newest_available}) for worker {worker_id} dp_rank {dp_rank}"
            );
            return Ok(0);
        }
        WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
            anyhow::bail!(
                "Invalid range for worker {worker_id} dp_rank {dp_rank}: \
                 end_id ({end_id}) < start_id ({start_id})"
            );
        }
        WorkerKvQueryResponse::Error(msg) => {
            anyhow::bail!("Worker {worker_id} dp_rank {dp_rank} query error: {msg}");
        }
    };

    if recovered.is_empty() {
        tracing::debug!("No events to recover from worker {worker_id} dp_rank {dp_rank}");
        return Ok(0);
    }

    let count = recovered.len();
    tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");

    // Replay in the order the worker returned them.
    for recovered_event in recovered {
        self.indexer.apply_event(recovered_event).await;
    }

    Ok(count)
}
} }
// ============================================================================ // ============================================================================
...@@ -443,11 +762,190 @@ impl AsyncEngine<SingleIn<WorkerKvQueryRequest>, ManyOut<WorkerKvQueryResponse>, ...@@ -443,11 +762,190 @@ impl AsyncEngine<SingleIn<WorkerKvQueryRequest>, ManyOut<WorkerKvQueryResponse>,
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::kv_router::Indexer;
use crate::kv_router::RouterEvent; use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::KvIndexerMetrics; use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData}; use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
use std::collections::VecDeque;
use std::sync::Mutex as StdMutex;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// One scripted reply for the mock query transport used by the tests in this module.
#[derive(Clone)]
struct MockQueryAction {
    /// When set, the mock signals this as soon as it starts handling the query.
    started: Option<Arc<Notify>>,
    /// When set, the mock waits on this before returning `response`.
    release: Option<Arc<Notify>>,
    /// The reply to hand back; an `Err(message)` is turned into an `anyhow` error.
    response: Result<WorkerKvQueryResponse, String>,
}
/// Test double for `WorkerQueryTransport` that replays queued `MockQueryAction`s and
/// records every query it receives.
#[derive(Default)]
struct MockWorkerQueryTransport {
    /// Per-`(worker_id, dp_rank)` FIFO of scripted replies.
    actions: DashMap<RecoveryKey, Arc<StdMutex<VecDeque<MockQueryAction>>>>,
    // The nested tuple type is only used here, so it is not worth a dedicated alias.
    /// Every query seen so far, as `(key, start_event_id, end_event_id)`, in call order.
    #[allow(clippy::type_complexity)]
    calls: Arc<StdMutex<Vec<(RecoveryKey, Option<u64>, Option<u64>)>>>,
}
impl MockWorkerQueryTransport {
    /// Appends `action` to the back of the scripted-reply queue for `key`,
    /// creating the queue on first use.
    fn push_action(&self, key: RecoveryKey, action: MockQueryAction) {
        // Take our own handle to the queue so the map entry guard is released
        // before the queue mutex is locked.
        let pending = {
            let slot = self
                .actions
                .entry(key)
                .or_insert_with(|| Arc::new(StdMutex::new(VecDeque::new())));
            slot.clone()
        };
        let mut pending = pending.lock().unwrap();
        pending.push_back(action);
    }

    /// Number of queries the mock has received so far.
    fn call_count(&self) -> usize {
        let recorded = self.calls.lock().unwrap();
        recorded.len()
    }
}
#[async_trait]
impl WorkerQueryTransport for MockWorkerQueryTransport {
    /// Records the call, then replays the next scripted action queued for
    /// `(worker_id, dp_rank)`.
    ///
    /// # Panics
    ///
    /// Panics when no queue exists for the key or the queue is empty, so an
    /// unexpected query fails the test loudly.
    async fn query_worker(
        &self,
        worker_id: WorkerId,
        dp_rank: DpRank,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let key = (worker_id, dp_rank);
        self.calls
            .lock()
            .unwrap()
            .push((key, start_event_id, end_event_id));

        let queue = self
            .actions
            .get(&key)
            .unwrap_or_else(|| {
                panic!("Missing action queue for worker {worker_id} dp_rank {dp_rank}")
            })
            .clone();
        let action = queue.lock().unwrap().pop_front().unwrap_or_else(|| {
            panic!("Missing action for worker {worker_id} dp_rank {dp_rank}")
        });

        if let Some(started) = action.started {
            // `notify_one` stores a permit when nobody is waiting yet, so the signal is
            // not lost if the test reaches `notified().await` only after this point.
            // `notify_waiters` would wake only tasks that are already registered.
            started.notify_one();
        }
        if let Some(release) = action.release {
            release.notified().await;
        }

        action.response.map_err(|message| anyhow::anyhow!(message))
    }
}
/// Builds a `Component` on a process-local distributed runtime, using `name` to derive
/// the namespace and component names.
async fn make_test_component(name: &str) -> Component {
    let config = DistributedConfig::process_local();
    let drt = DistributedRuntime::new(Runtime::from_current().unwrap(), config)
        .await
        .unwrap();

    let namespace_name = format!("test-ns-{name}");
    let component_name = format!("test-component-{name}");
    drt.namespace(namespace_name)
        .unwrap()
        .component(component_name)
        .unwrap()
}
/// Creates a `KvIndexer` with unregistered metrics and returns it together with the
/// `Indexer` wrapper around a clone of it, so tests can inspect what the client applied.
fn make_test_indexer() -> (KvIndexer, Indexer) {
    let kv_indexer = KvIndexer::new(
        CancellationToken::new(),
        4,
        Arc::new(KvIndexerMetrics::new_unregistered()),
    );
    let wrapped = Indexer::KvIndexer(kv_indexer.clone());
    (kv_indexer, wrapped)
}
/// Wires a `WorkerQueryClient` to a fresh mock transport and test indexer.
///
/// Returns the client, the mock (for scripting replies and counting calls) and the
/// underlying `KvIndexer` (for inspecting applied events).
async fn make_test_client(
    name: &str,
) -> (
    Arc<WorkerQueryClient>,
    Arc<MockWorkerQueryTransport>,
    KvIndexer,
) {
    let component = make_test_component(name).await;
    let (kv_indexer, indexer) = make_test_indexer();
    let mock_transport = Arc::new(MockWorkerQueryTransport::default());

    (
        WorkerQueryClient::new(component, indexer, mock_transport.clone()),
        mock_transport,
        kv_indexer,
    )
}
/// Builds a single-block `Stored` event whose block hash and tokens hash both equal
/// `event_id`, which lets tests identify events by hash.
fn make_store_event(worker_id: WorkerId, dp_rank: DpRank, event_id: u64) -> RouterEvent {
    let block = KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(event_id),
        tokens_hash: LocalBlockHash(event_id),
        mm_extra_info: None,
    };
    let data = KvCacheEventData::Stored(KvCacheStoreData {
        parent_hash: None,
        blocks: vec![block],
    });
    let event = KvCacheEvent {
        event_id,
        data,
        dp_rank,
    };
    RouterEvent::new(worker_id, event)
}
/// Builds a `Cleared` event with the given id for `(worker_id, dp_rank)`.
fn make_clear_event(worker_id: WorkerId, dp_rank: DpRank, event_id: u64) -> RouterEvent {
    let event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Cleared,
        dp_rank,
    };
    RouterEvent::new(worker_id, event)
}
/// Collects the first block hash of every `Stored` event in `events`, sorted ascending.
/// Non-`Stored` events and `Stored` events without blocks are skipped.
fn stored_block_hashes(events: &[RouterEvent]) -> Vec<u64> {
    let mut hashes = Vec::new();
    for event in events {
        let KvCacheEventData::Stored(data) = &event.event.data else {
            continue;
        };
        if let Some(first_block) = data.blocks.first() {
            hashes.push(first_block.block_hash.0);
        }
    }
    hashes.sort_unstable();
    hashes
}
/// Polls `check` up to 100 times, sleeping 10 ms after each failed attempt.
///
/// # Panics
///
/// Panics if `check` never returns `true`.
async fn wait_for<F>(mut check: F)
where
    F: FnMut() -> bool,
{
    let mut attempts_left = 100u32;
    while attempts_left > 0 {
        if check() {
            return;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
        attempts_left -= 1;
    }
    panic!("condition not met before timeout");
}
/// Returns whether the rank state for `key` exists and satisfies `check`.
///
/// Yields `false` when the worker or rank is not tracked, and also when the worker's
/// state lock is currently held (the lock is only tried, never awaited).
fn rank_state_matches<F>(client: &Arc<WorkerQueryClient>, key: RecoveryKey, check: F) -> bool
where
    F: FnOnce(&RankState) -> bool,
{
    let Some(entry) = client.worker_states.get(&key.0) else {
        return false;
    };
    let Ok(locked) = entry.try_lock() else {
        return false;
    };
    locked.ranks.get(&key.1).is_some_and(check)
}
#[tokio::test] #[tokio::test]
async fn test_worker_kv_query_engine_returns_buffered_events() { async fn test_worker_kv_query_engine_returns_buffered_events() {
let worker_id = 7u64; let worker_id = 7u64;
...@@ -497,4 +995,345 @@ mod tests { ...@@ -497,4 +995,345 @@ mod tests {
other => panic!("Unexpected response: {other:?}"), other => panic!("Unexpected response: {other:?}"),
} }
} }
#[tokio::test]
async fn test_discovery_restore_does_not_block_other_workers() {
let (client, transport, kv_indexer) = make_test_client("discovery-concurrency").await;
let first_started = Arc::new(Notify::new());
let first_release = Arc::new(Notify::new());
transport.push_action(
(1, 0),
MockQueryAction {
started: Some(first_started.clone()),
release: Some(first_release.clone()),
response: Ok(WorkerKvQueryResponse::TreeDump {
events: vec![],
last_event_id: 0,
}),
},
);
transport.push_action(
(2, 0),
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::TreeDump {
events: vec![],
last_event_id: 0,
}),
},
);
client.handle_discovered_worker(1, 0).await;
first_started.notified().await;
client.handle_discovered_worker(2, 0).await;
wait_for(|| transport.call_count() == 2).await;
first_release.notify_waiters();
kv_indexer.flush().await;
}
/// Live events that arrive while a gap recovery is in flight must be picked up by a
/// follow-up recovery, so the rank ends at the highest id seen.
#[tokio::test]
async fn test_gap_recovery_follows_high_water_mark() {
    let (client, transport, kv_indexer) = make_test_client("high-water").await;
    let key = (1, 0);

    // Start from an established cursor at event 10.
    client
        .get_or_create_worker_state(key.0)
        .lock()
        .await
        .ranks
        .entry(key.1)
        .or_default()
        .cursor = RankCursor::Live(10);

    let gate_started = Arc::new(Notify::new());
    let gate_release = Arc::new(Notify::new());

    // First recovery (parked): events 11..=15. Follow-up recovery: events 16..=18.
    let first_batch: Vec<_> = (11..=15).map(|id| make_store_event(1, 0, id)).collect();
    let second_batch: Vec<_> = (16..=18).map(|id| make_store_event(1, 0, id)).collect();
    transport.push_action(
        key,
        MockQueryAction {
            started: Some(gate_started.clone()),
            release: Some(gate_release.clone()),
            response: Ok(WorkerKvQueryResponse::Events(first_batch)),
        },
    );
    transport.push_action(
        key,
        MockQueryAction {
            started: None,
            release: None,
            response: Ok(WorkerKvQueryResponse::Events(second_batch)),
        },
    );

    // Event 15 after cursor 10 is a gap and starts the first recovery.
    client.handle_live_event(make_store_event(1, 0, 15)).await;
    gate_started.notified().await;

    // These arrive while the first recovery is still parked.
    for live_id in 16..=18 {
        client
            .handle_live_event(make_store_event(1, 0, live_id))
            .await;
    }
    gate_release.notify_waiters();

    wait_for(|| {
        rank_state_matches(&client, key, |state| {
            state.last_applied_id() == Some(18) && !state.recovery_inflight
        })
    })
    .await;

    kv_indexer.flush().await;
    let applied = kv_indexer.dump_events().await.unwrap();
    assert_eq!(
        stored_block_hashes(&applied),
        vec![11, 12, 13, 14, 15, 16, 17, 18]
    );
}
/// After the initial tree-dump restore sets the cursor, a contiguous live event is
/// applied directly and a gapped one triggers exactly one incremental recovery.
#[tokio::test]
async fn test_initial_restore_updates_cursor_for_live_and_gap_paths() {
    let (client, transport, kv_indexer) = make_test_client("initial-restore-cursor").await;
    let key = (1, 0);

    // Reply 1: the initial restore, an empty dump positioned at event 10.
    transport.push_action(
        key,
        MockQueryAction {
            started: None,
            release: None,
            response: Ok(WorkerKvQueryResponse::TreeDump {
                events: vec![],
                last_event_id: 10,
            }),
        },
    );
    // Reply 2: the gap recovery, events 12 and 13.
    let gap_events = vec![make_store_event(1, 0, 12), make_store_event(1, 0, 13)];
    transport.push_action(
        key,
        MockQueryAction {
            started: None,
            release: None,
            response: Ok(WorkerKvQueryResponse::Events(gap_events)),
        },
    );

    // Discovery triggers the restore; the cursor lands on 10.
    client.handle_discovered_worker(1, 0).await;
    wait_for(|| {
        rank_state_matches(&client, key, |state| {
            state.last_applied_id() == Some(10) && !state.recovery_inflight
        })
    })
    .await;
    assert_eq!(transport.call_count(), 1);

    // Event 11 is contiguous: applied directly, no further query.
    client.handle_live_event(make_store_event(1, 0, 11)).await;
    wait_for(|| {
        rank_state_matches(&client, key, |state| {
            state.last_applied_id() == Some(11) && !state.recovery_inflight
        })
    })
    .await;
    assert_eq!(transport.call_count(), 1);

    // Event 13 skips 12: one incremental recovery is issued.
    client.handle_live_event(make_store_event(1, 0, 13)).await;
    wait_for(|| {
        rank_state_matches(&client, key, |state| {
            state.last_applied_id() == Some(13) && !state.recovery_inflight
        })
    })
    .await;
    assert_eq!(transport.call_count(), 2);

    kv_indexer.flush().await;
    let applied = kv_indexer.dump_events().await.unwrap();
    assert_eq!(stored_block_hashes(&applied), vec![11, 12, 13]);
}
/// While worker 1's recovery is parked in the transport, a contiguous live event for
/// worker 2 must still be applied to the indexer.
#[tokio::test]
async fn test_live_event_for_other_worker_is_not_blocked_by_inflight_recovery() {
    let (client, transport, kv_indexer) = make_test_client("live-concurrency").await;

    // Worker 1 starts at cursor 10, worker 2 at cursor 20.
    let delayed_key = (1, 0);
    client
        .get_or_create_worker_state(delayed_key.0)
        .lock()
        .await
        .ranks
        .entry(delayed_key.1)
        .or_default()
        .cursor = RankCursor::Live(10);

    let other_key = (2, 0);
    client
        .get_or_create_worker_state(other_key.0)
        .lock()
        .await
        .ranks
        .entry(other_key.1)
        .or_default()
        .cursor = RankCursor::Live(20);

    let gate_started = Arc::new(Notify::new());
    let gate_release = Arc::new(Notify::new());
    let delayed_events = vec![
        make_store_event(1, 0, 11),
        make_store_event(1, 0, 12),
        make_store_event(1, 0, 13),
    ];
    transport.push_action(
        delayed_key,
        MockQueryAction {
            started: Some(gate_started.clone()),
            release: Some(gate_release.clone()),
            response: Ok(WorkerKvQueryResponse::Events(delayed_events)),
        },
    );

    // The gap on worker 1 starts a recovery that parks in the transport.
    client.handle_live_event(make_store_event(1, 0, 13)).await;
    gate_started.notified().await;

    // Worker 2's next event is contiguous and must go straight through.
    client.handle_live_event(make_store_event(2, 0, 21)).await;
    kv_indexer.flush().await;

    let applied = kv_indexer.dump_events().await.unwrap();
    let worker_two_event_applied = applied.iter().any(|event| {
        event.worker_id == 2
            && event.event.dp_rank == 0
            && matches!(
                &event.event.data,
                KvCacheEventData::Stored(data)
                    if data.blocks.first().map(|block| block.block_hash.0) == Some(21)
            )
    });
    assert!(worker_two_event_applied);

    gate_release.notify_waiters();
}
/// If the worker's only rank is removed while its recovery is still in flight, the
/// recovery's events must not end up in the indexer.
#[tokio::test]
async fn test_worker_removal_discards_late_recovery_result() {
    let (client, transport, kv_indexer) = make_test_client("remove-race").await;
    let key = (1, 0);

    client
        .get_or_create_worker_state(key.0)
        .lock()
        .await
        .ranks
        .entry(key.1)
        .or_default()
        .cursor = RankCursor::Live(10);

    let gate_started = Arc::new(Notify::new());
    let gate_release = Arc::new(Notify::new());
    let late_events = vec![make_store_event(1, 0, 11), make_store_event(1, 0, 12)];
    transport.push_action(
        key,
        MockQueryAction {
            started: Some(gate_started.clone()),
            release: Some(gate_release.clone()),
            response: Ok(WorkerKvQueryResponse::Events(late_events)),
        },
    );

    // The gap starts a recovery, which parks in the transport.
    client.handle_live_event(make_store_event(1, 0, 12)).await;
    gate_started.notified().await;

    // Remove the rank (and with it the worker) before the recovery answers.
    client.handle_removed_worker_dp(1, 0).await;
    gate_release.notify_waiters();

    // Wait until the rank state is no longer observable.
    wait_for(|| !rank_state_matches(&client, key, |_| true)).await;

    kv_indexer.flush().await;
    let applied = kv_indexer.dump_events().await.unwrap();
    assert!(applied.is_empty());
}
/// A live `Cleared` that arrives while a recovery is in flight must invalidate that
/// recovery's result, and both ranks must resume from their next live event without
/// issuing another query.
#[tokio::test]
async fn test_live_cleared_invalidates_inflight_recovery_without_restore() {
    let (client, transport, kv_indexer) = make_test_client("live-cleared-no-restore").await;
    let key0 = (1, 0);
    let key1 = (1, 1);

    // Two ranks of worker 1 with established cursors.
    {
        let state_handle = client.get_or_create_worker_state(1);
        let mut state = state_handle.lock().await;
        state.ranks.entry(0).or_default().cursor = RankCursor::Live(10);
        state.ranks.entry(1).or_default().cursor = RankCursor::Live(20);
    }

    let gate_started = Arc::new(Notify::new());
    let gate_release = Arc::new(Notify::new());
    let stale_events = vec![
        make_store_event(1, 0, 11),
        make_store_event(1, 0, 12),
        make_store_event(1, 0, 13),
    ];
    transport.push_action(
        key0,
        MockQueryAction {
            started: Some(gate_started.clone()),
            release: Some(gate_release.clone()),
            response: Ok(WorkerKvQueryResponse::Events(stale_events)),
        },
    );

    // The gap on rank 0 starts a recovery that parks in the transport.
    client.handle_live_event(make_store_event(1, 0, 13)).await;
    gate_started.notified().await;

    // The clear arrives while that recovery is still parked.
    client.handle_live_event(make_clear_event(1, 0, 14)).await;
    wait_for(|| transport.call_count() == 1).await;
    gate_release.notify_waiters();

    wait_for(|| {
        let rank0_settled = rank_state_matches(&client, key0, |state| {
            state.last_applied_id() == Some(14) && !state.recovery_inflight
        });
        rank0_settled
            && rank_state_matches(&client, key1, |state| {
                state.last_applied_id() == Some(20) && !state.recovery_inflight
            })
    })
    .await;

    // Both ranks pick up from their next live event.
    client.handle_live_event(make_store_event(1, 0, 15)).await;
    client.handle_live_event(make_store_event(1, 1, 30)).await;
    kv_indexer.flush().await;

    let applied = kv_indexer.dump_events().await.unwrap();
    assert_eq!(transport.call_count(), 1);
    assert_eq!(stored_block_hashes(&applied), vec![15, 30]);
}
/// A `Cleared` replayed inside a recovery response acts as a barrier: only events after
/// it survive, and both ranks continue from live traffic without another query.
#[tokio::test]
async fn test_recovered_cleared_resumes_live_without_restore() {
    let (client, transport, kv_indexer) =
        make_test_client("recovered-cleared-no-restore").await;
    let key0 = (1, 0);
    let key1 = (1, 1);

    // Two ranks of worker 1 with established cursors.
    {
        let state_handle = client.get_or_create_worker_state(1);
        let mut state = state_handle.lock().await;
        state.ranks.entry(0).or_default().cursor = RankCursor::Live(10);
        state.ranks.entry(1).or_default().cursor = RankCursor::Live(20);
    }

    // The recovery response contains store 11, clear 12, store 13.
    let replayed = vec![
        make_store_event(1, 0, 11),
        make_clear_event(1, 0, 12),
        make_store_event(1, 0, 13),
    ];
    transport.push_action(
        key0,
        MockQueryAction {
            started: None,
            release: None,
            response: Ok(WorkerKvQueryResponse::Events(replayed)),
        },
    );

    // The gap on rank 0 triggers the recovery.
    client.handle_live_event(make_store_event(1, 0, 13)).await;
    wait_for(|| {
        let rank0_settled = rank_state_matches(&client, key0, |state| {
            state.last_applied_id() == Some(13) && !state.recovery_inflight
        });
        rank0_settled
            && rank_state_matches(&client, key1, |state| {
                state.last_applied_id() == Some(20) && !state.recovery_inflight
            })
    })
    .await;
    assert_eq!(transport.call_count(), 1);

    // Live traffic resumes on both ranks.
    client.handle_live_event(make_store_event(1, 0, 14)).await;
    client.handle_live_event(make_store_event(1, 1, 30)).await;
    kv_indexer.flush().await;

    let applied = kv_indexer.dump_events().await.unwrap();
    assert_eq!(stored_block_hashes(&applied), vec![13, 14, 30]);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment