Unverified Commit 7ff5e0be authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Concurrent KV event consumer (#7293)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent 0f8e1a9e
......@@ -18,6 +18,74 @@ use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*;
use dynamo_tokens::SequenceHash;
fn stored_block_entries(event: &RouterEvent) -> Option<Vec<BlockEntry>> {
let KvCacheEventData::Stored(ref store_data) = event.event.data else {
return None;
};
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
Some(
store_data
.blocks
.iter()
.enumerate()
.map(|(idx, block)| BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
})
.collect(),
)
}
fn apply_event_with_prune_tracking(
trie: &mut RadixTree,
event: RouterEvent,
metrics: &KvIndexerMetrics,
prune_manager: &mut Option<PruneManager<BlockEntry>>,
prune_tx: &mpsc::Sender<()>,
) {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let event_id = event.event.event_id;
let worker_id = event.worker_id;
let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event);
let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
);
metrics.increment_event_applied(event_type, result);
let Some(pm) = prune_manager.as_mut() else {
return;
};
if !result_is_ok {
return;
}
let Some(ref event) = event_for_prune else {
return;
};
let Some(block_entries) = stored_block_entries(event) else {
return;
};
pm.insert(block_entries);
let Some(ref pc) = pm.prune_config else {
return;
};
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = prune_tx.try_send(());
}
}
/// The KV Indexer, managing the KV store and handling events and match requests.
#[derive(Clone)]
pub struct KvIndexer {
......@@ -64,7 +132,7 @@ impl KvIndexer {
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
......@@ -151,49 +219,26 @@ impl KvIndexer {
}
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let event_id = event.event.event_id;
let worker_id = event.worker_id;
// Only clone if we need the event for prune_manager afterward
let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event);
let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
apply_event_with_prune_tracking(
&mut trie,
event,
&metrics,
&mut prune_manager,
&prune_tx,
);
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
let Some(ref mut pm) = prune_manager else { continue };
if !result_is_ok { continue };
let Some(ref event) = event_for_prune else { continue };
let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
let block_entries: Vec<BlockEntry> = store_data.blocks.iter().enumerate().map(|(idx, block)| {
BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = prune_tx.try_send(());
}
}
Some(dump_req) = dump_rx.recv() => {
// Flush pending events so tree is consistent with buffer
while let Ok(event) = event_rx.try_recv() {
apply_event_with_prune_tracking(
&mut trie,
event,
&metrics,
&mut prune_manager,
&prune_tx,
);
}
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
......
......@@ -53,6 +53,15 @@ impl LocalKvIndexer {
buffer.iter().cloned().collect()
}
/// Build a tree dump response with the given `last_event_id`.
async fn tree_dump_response(&self, last_event_id: u64) -> WorkerKvQueryResponse {
let events = self.dump_events().await.unwrap_or_default();
WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
}
}
/// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive).
///
/// ### Arguments
......@@ -63,7 +72,7 @@ impl LocalKvIndexer {
/// ### Returns
///
/// - `Events`: Buffered events with original IDs (when range is within buffer)
/// - `TreeDump`: Full tree dump with synthetic IDs (when range is too old or unspecified)
/// - `TreeDump`: Full tree dump with synthetic IDs and the worker's latest real event ID (when range is too old or unspecified)
/// - `TooNew`: Error when requested range is newer than available data
/// - `InvalidRange`: Error when end_id < start_id
pub async fn get_events_in_id_range(
......@@ -98,8 +107,7 @@ impl LocalKvIndexer {
// If no start_id specified, dump entire tree
if start_id.is_none() {
tracing::debug!("No start_id specified, dumping entire tree");
let events = self.dump_events().await.unwrap_or_default();
return WorkerKvQueryResponse::TreeDump(events);
return self.tree_dump_response(last_id.unwrap_or(0)).await;
}
let start_id = start_id.unwrap();
......@@ -108,8 +116,7 @@ impl LocalKvIndexer {
// Check for empty buffer
let Some(first_buffered) = first_id else {
tracing::debug!("Buffer empty, dumping entire tree");
let events = self.dump_events().await.unwrap_or_default();
return WorkerKvQueryResponse::TreeDump(events);
return self.tree_dump_response(0).await;
};
let last_buffered = last_id.unwrap();
......@@ -134,8 +141,7 @@ impl LocalKvIndexer {
first_buffered,
"Requested start_id is older than buffer, dumping entire tree"
);
let events = self.dump_events().await.unwrap_or_default();
return WorkerKvQueryResponse::TreeDump(events);
return self.tree_dump_response(last_buffered).await;
}
// Serve from buffer
......@@ -196,17 +202,20 @@ impl LocalKvIndexer {
/// Apply event with buffering.
///
/// This records the event in the buffer and forwards it to the underlying indexer.
/// This forwards the event to the underlying indexer and records it on success.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
// Record in buffer
self.record_event(event.clone());
// Forward to underlying indexer
self.indexer
let result = self
.indexer
.event_sender()
.send(event)
.send(event.clone())
.await
.map_err(|_| KvRouterError::IndexerOffline)
.map_err(|_| KvRouterError::IndexerOffline);
if result.is_ok() {
self.record_event(event);
}
result
}
/// Clear the event buffer.
......
......@@ -1941,7 +1941,7 @@ async fn test_local_indexer_slice_within_range() {
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type"),
}
};
......@@ -1962,7 +1962,7 @@ async fn test_local_indexer_slice_within_range() {
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
let ids = get_ids(extract_events(result));
......@@ -2016,7 +2016,7 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump(e) => e,
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
......@@ -2038,11 +2038,11 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
// Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
assert_eq!(extract_events(result).len(), 10);
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
......@@ -2052,6 +2052,98 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
#[tokio::test]
async fn test_tree_dump_includes_last_event_id() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Empty buffer case: create a fresh indexer with no events
let empty_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 0,
"empty buffer should return last_event_id=0"
);
assert!(events.is_empty(), "empty indexer should have no events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
let worker_id = 42u64;
......@@ -2099,6 +2191,51 @@ async fn test_local_indexer_buffer_and_serialization() {
assert_eq!(events[0].worker_id, worker_id);
}
#[tokio::test]
async fn test_local_indexer_does_not_buffer_failed_send() {
let local_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let test_event = RouterEvent::new(
7,
KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
);
let event_tx = local_indexer.event_sender();
local_indexer.shutdown();
event_tx.closed().await;
let result = local_indexer.apply_event_with_buffer(test_event).await;
assert!(matches!(result, Err(KvRouterError::IndexerOffline)));
assert_eq!(local_indexer.buffer_len(), 0);
match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
} => {
assert!(events.is_empty());
assert_eq!(last_event_id, 0);
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_events_idempotent(variant: &str) {
......
......@@ -56,8 +56,13 @@ pub struct WorkerKvQueryRequest {
pub enum WorkerKvQueryResponse {
/// Events served from the circular buffer (with original event IDs)
Events(Vec<RouterEvent>),
/// Full tree dump (with synthetic 0-indexed event IDs)
TreeDump(Vec<RouterEvent>),
/// Full tree dump (with synthetic 0-indexed event IDs).
/// Includes `last_event_id`: the newest real event ID in the worker's buffer
/// at the time of the dump, so the caller can set its tracking cursor correctly.
TreeDump {
events: Vec<RouterEvent>,
last_event_id: u64,
},
/// Requested range is newer than available data
TooNew {
requested_start: Option<u64>,
......
......@@ -2024,7 +2024,7 @@ mod tests_startup_helpers {
.await;
let missed_events = match response {
crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e,
crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump(e) => e,
crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
crate::kv_router::indexer::WorkerKvQueryResponse::Error(message) => {
panic!("Unexpected error response: {message}")
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use crate::kv_router::{
Indexer, KV_EVENT_SUBJECT, KvRouterConfig,
protocols::{DpRank, RouterEvent, WorkerId},
Indexer, KV_EVENT_SUBJECT, KvRouterConfig, protocols::RouterEvent,
worker_query::WorkerQueryClient,
};
use anyhow::Result;
......@@ -23,9 +20,6 @@ use dynamo_runtime::{
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer
///
/// This function first recovers state from all currently registered workers before
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled.
async fn start_kv_router_background_event_plane(
component: Component,
......@@ -47,7 +41,9 @@ async fn start_kv_router_background_event_plane(
// before recovery fetches the initial dump from workers.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer.clone()).await?;
// WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer).await?;
let kv_event_subject = format!(
"namespace.{}.component.{}.{}",
component.namespace().name(),
......@@ -71,10 +67,6 @@ async fn start_kv_router_background_event_plane(
}
tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection
// Each dp_rank has its own monotonic event ID sequence
let mut last_event_ids: HashMap<(WorkerId, DpRank), u64> = HashMap::new();
loop {
tokio::select! {
biased;
......@@ -94,47 +86,19 @@ async fn start_kv_router_background_event_plane(
}
};
let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank);
tracing::trace!(
"Received event from publisher {} (seq {})",
envelope.publisher_id,
envelope.sequence
);
// Gap detection: check if event ID is monotonically increasing per (worker, dp_rank)
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent)
if let Some(&last_id) = last_event_ids.get(&event_key)
&& event_id > last_id + 1
{
let gap_start = last_id + 1;
let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1;
tracing::warn!(
"Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
if let Err(e) = worker_query_client
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end))
.await
{
tracing::error!(
"Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
);
}
}
// Update last seen event ID (use max to handle out-of-order)
last_event_ids
.entry(event_key)
.and_modify(|id| *id = (*id).max(event_id))
.or_insert(event_id);
// Forward the RouterEvent to the indexer
indexer.apply_event(event).await;
tracing::trace!(
"Forwarding live event to recovery coordinator for worker {} dp_rank {} event_id {}",
event.worker_id,
event.event.dp_rank,
event.event.event_id
);
worker_query_client.handle_live_event(event).await;
}
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use async_trait::async_trait;
use dashmap::DashMap;
use dynamo_runtime::component::Component;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode,
SingleIn, async_trait, network::Ingress,
SingleIn, network::Ingress,
};
use dynamo_runtime::protocols::maybe_error::MaybeError;
use dynamo_runtime::stream;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use futures::StreamExt;
use tokio::sync::{Mutex, Semaphore};
use crate::kv_router::Indexer;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::{DpRank, WorkerId};
use crate::kv_router::protocols::{DpRank, KvCacheEventData, RouterEvent, WorkerId};
use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration
const RECOVERY_MAX_RETRIES: u32 = 8;
const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
const RECOVERY_CONCURRENCY_LIMIT: usize = 16;
/// Prefix for worker KV indexer query endpoint names.
const QUERY_ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp";
type RecoveryKey = (WorkerId, DpRank);
#[derive(Clone, Copy, Debug, Default)]
enum RankCursor {
#[default]
NeedsRestore,
Live(u64),
InvalidatedByBarrier(Option<u64>),
}
#[derive(Debug, Default)]
struct RankState {
cursor: RankCursor,
max_seen_live_id: Option<u64>,
recovery_inflight: bool,
}
impl RankState {
fn last_applied_id(&self) -> Option<u64> {
match self.cursor {
RankCursor::NeedsRestore => None,
RankCursor::Live(event_id) => Some(event_id),
RankCursor::InvalidatedByBarrier(last_applied_id) => last_applied_id,
}
}
fn observe_live_id(&mut self, event_id: u64) {
self.max_seen_live_id = Some(self.max_seen_live_id.unwrap_or(0).max(event_id));
}
fn clear_max_seen_if_caught_up(&mut self, last_applied_id: u64) {
if self
.max_seen_live_id
.is_some_and(|max_seen| max_seen <= last_applied_id)
{
self.max_seen_live_id = None;
}
}
}
#[derive(Debug, Default)]
struct WorkerState {
epoch: u64,
ranks: HashMap<DpRank, RankState>,
}
#[async_trait]
trait WorkerQueryTransport: Send + Sync {
async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse>;
}
struct RuntimeWorkerQueryTransport {
component: Component,
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
}
impl RuntimeWorkerQueryTransport {
fn new(component: Component) -> Self {
Self {
component,
routers: DashMap::new(),
}
}
async fn get_router_for_dp_rank(
&self,
dp_rank: DpRank,
) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> {
if let Some(router) = self.routers.get(&dp_rank) {
return Ok(router.clone());
}
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
let endpoint = self.component.endpoint(&endpoint_name);
let client = endpoint.client().await?;
let router = Arc::new(
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?,
);
Ok(self.routers.entry(dp_rank).or_insert(router).clone())
}
}
#[async_trait]
impl WorkerQueryTransport for RuntimeWorkerQueryTransport {
async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
let router = self.get_router_for_dp_rank(dp_rank).await?;
let request = WorkerKvQueryRequest {
worker_id,
start_event_id,
end_event_id,
};
let mut stream = router
.direct(SingleIn::new(request), worker_id)
.await
.with_context(|| {
format!("Failed to send worker KV query to worker {worker_id} dp_rank {dp_rank}")
})?;
let response = stream
.next()
.await
.context("Worker KV query returned an empty response stream")?;
if let Some(err) = response.err() {
return Err(err).context("Worker KV query response error");
}
Ok(response)
}
}
/// Router-side client for querying worker local KV indexers.
///
/// Discovers query endpoints via `ComponentEndpoints` discovery, filtering for
/// the `worker_kv_indexer_query_dp{N}` name pattern. Recovers each
/// `(worker_id, dp_rank)` individually as it appears in discovery.
/// the `worker_kv_indexer_query_dp{N}` name pattern. Coordinates restore and
/// gap recovery at the worker level while still querying each `(worker_id,
/// dp_rank)` endpoint independently.
///
/// Also handles worker lifecycle (add/remove) by tracking known endpoints and
/// sending removal events to the router indexer when all dp_ranks for a worker
/// disappear.
pub struct WorkerQueryClient {
component: Component,
/// Routers keyed by dp_rank — each dp_rank has its own endpoint. Created lazily.
routers: Arc<DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>>,
transport: Arc<dyn WorkerQueryTransport>,
/// Indexer for applying recovered events and worker removals.
indexer: Indexer,
worker_states: DashMap<WorkerId, Arc<Mutex<WorkerState>>>,
recovery_semaphore: Arc<Semaphore>,
}
impl WorkerQueryClient {
fn new(
component: Component,
indexer: Indexer,
transport: Arc<dyn WorkerQueryTransport>,
) -> Arc<Self> {
Arc::new(Self {
component,
transport,
indexer,
worker_states: DashMap::new(),
recovery_semaphore: Arc::new(Semaphore::new(RECOVERY_CONCURRENCY_LIMIT)),
})
}
/// Create a new WorkerQueryClient and spawn its background discovery loop.
///
/// The background loop watches `ComponentEndpoints` discovery for query endpoints,
/// recovers each `(worker_id, dp_rank)` as it appears, and sends worker removal
/// events when all dp_ranks for a worker disappear.
pub async fn spawn(component: Component, indexer: Indexer) -> Result<Arc<Self>> {
let client = Arc::new(Self {
component: component.clone(),
routers: Arc::new(DashMap::new()),
indexer,
});
let transport = Arc::new(RuntimeWorkerQueryTransport::new(component.clone()));
let client = Self::new(component.clone(), indexer, transport);
let client_bg = client.clone();
let cancel_token = component.drt().primary_token();
......@@ -71,9 +212,9 @@ impl WorkerQueryClient {
Ok(client)
}
/// Background loop: watches ComponentEndpoints, recovers per (worker_id, dp_rank).
/// Background loop: watches ComponentEndpoints and schedules worker-coordinated recovery.
async fn run_discovery_loop(
&self,
self: Arc<Self>,
cancel_token: tokio_util::sync::CancellationToken,
) -> Result<()> {
let discovery = self.component.drt().discovery();
......@@ -87,9 +228,6 @@ impl WorkerQueryClient {
)
.await?;
// Track known (worker_id, dp_rank) pairs to detect removals
let mut known: HashMap<WorkerId, HashSet<DpRank>> = HashMap::new();
while let Some(result) = stream.next().await {
if cancel_token.is_cancelled() {
break;
......@@ -108,45 +246,13 @@ impl WorkerQueryClient {
let Some((worker_id, dp_rank)) = Self::parse_query_endpoint(&instance) else {
continue;
};
if known.entry(worker_id).or_default().insert(dp_rank) {
tracing::info!(
"WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, recovering"
);
match self
.recover_from_worker(worker_id, dp_rank, None, None)
.await
{
Ok(count) => {
if count > 0 {
tracing::info!(
"Recovered {count} events from worker {worker_id} dp_rank {dp_rank}"
);
}
}
Err(e) => {
tracing::warn!(
"Failed to recover from worker {worker_id} dp_rank {dp_rank}: {e}"
);
}
}
}
self.handle_discovered_worker(worker_id, dp_rank).await;
}
DiscoveryEvent::Removed(id) => {
let Some((worker_id, dp_rank)) = Self::parse_instance_id(&id) else {
continue;
};
if let Some(dp_ranks) = known.get_mut(&worker_id) {
dp_ranks.remove(&dp_rank);
if dp_ranks.is_empty() {
known.remove(&worker_id);
tracing::warn!(
"WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing"
);
self.indexer.remove_worker(worker_id).await;
}
}
self.handle_removed_worker_dp(worker_id, dp_rank).await;
}
}
}
......@@ -177,76 +283,362 @@ impl WorkerQueryClient {
Some((eid.instance_id, dp_rank))
}
/// Get or create a router for the specified dp_rank's endpoint.
async fn get_router_for_dp_rank(
&self,
fn get_or_create_worker_state(&self, worker_id: WorkerId) -> Arc<Mutex<WorkerState>> {
self.worker_states
.entry(worker_id)
.or_insert_with(|| Arc::new(Mutex::new(WorkerState::default())))
.clone()
}
pub(crate) async fn handle_discovered_worker(
self: &Arc<Self>,
worker_id: WorkerId,
dp_rank: DpRank,
) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> {
if let Some(router) = self.routers.get(&dp_rank) {
return Ok(router.clone());
) {
let worker_state = self.get_or_create_worker_state(worker_id);
let spawn = {
let mut worker_state = worker_state.lock().await;
let rank_state = worker_state.ranks.entry(dp_rank).or_default();
if matches!(rank_state.cursor, RankCursor::NeedsRestore)
&& !rank_state.recovery_inflight
{
tracing::info!(
"WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, scheduling restore"
);
rank_state.recovery_inflight = true;
Some(worker_state.epoch)
} else {
None
}
};
if let Some(epoch) = spawn {
self.spawn_recovery_task((worker_id, dp_rank), epoch, None, None);
}
}
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
let endpoint = self.component.endpoint(&endpoint_name);
let client = endpoint.client().await?;
let router = Arc::new(
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?,
pub(crate) async fn handle_removed_worker_dp(&self, worker_id: WorkerId, dp_rank: DpRank) {
let Some(worker_state) = self
.worker_states
.get(&worker_id)
.map(|entry| entry.clone())
else {
return;
};
let should_remove_worker = {
let mut worker_state = worker_state.lock().await;
if worker_state.ranks.remove(&dp_rank).is_none() {
return;
}
worker_state.ranks.is_empty()
};
if should_remove_worker {
tracing::warn!("WorkerQueryClient: all dp_ranks gone for worker {worker_id}, removing");
self.worker_states.remove(&worker_id);
self.indexer.remove_worker(worker_id).await;
}
}
async fn apply_worker_clear_locked(&self, worker_state: &mut WorkerState, event: RouterEvent) {
let worker_id = event.worker_id;
let clear_dp_rank = event.event.dp_rank;
let clear_event_id = event.event.event_id;
worker_state.epoch += 1;
for rank_state in worker_state.ranks.values_mut() {
rank_state.cursor = RankCursor::InvalidatedByBarrier(rank_state.last_applied_id());
rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false;
}
let rank_state = worker_state.ranks.entry(clear_dp_rank).or_default();
rank_state.cursor = RankCursor::Live(clear_event_id);
tracing::info!(
"Applying clear barrier for worker {worker_id}; invalidating recovery across {} dp_ranks",
worker_state.ranks.len()
);
self.indexer.apply_event(event).await;
}
pub(crate) async fn handle_live_event(self: &Arc<Self>, event: RouterEvent) {
let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id;
let key = (worker_id, dp_rank);
enum Action {
ApplyDirect,
SpawnFullRestore { epoch: u64 },
SpawnIncremental { epoch: u64, start_event_id: u64 },
}
let action = {
let worker_state = self.get_or_create_worker_state(worker_id);
let mut worker_state = worker_state.lock().await;
let rank_state = worker_state.ranks.entry(dp_rank).or_default();
// `Cleared` is worker-wide in the indexer, so it bypasses the per-rank gap logic and
// instead installs a worker barrier that invalidates any inflight recovery.
if matches!(&event.event.data, KvCacheEventData::Cleared) {
if rank_state
.last_applied_id()
.is_none_or(|last_applied_id| event_id > last_applied_id)
{
self.apply_worker_clear_locked(&mut worker_state, event)
.await;
}
// Already applied the event, so no further action needed.
return;
} else {
match rank_state.cursor {
// We have never established a cursor for this rank, so live traffic only tells
// us how far ahead the stream has moved while a full restore catches up.
RankCursor::NeedsRestore => {
rank_state.observe_live_id(event_id);
if !rank_state.recovery_inflight {
rank_state.recovery_inflight = true;
Action::SpawnFullRestore {
epoch: worker_state.epoch,
}
} else {
// A recovery is already in flight. Nothing to do.
return;
}
}
// Normal steady-state path: apply contiguous events directly, but coalesce any
// gap into a single recovery pass using `max_seen_live_id` as the high-water mark.
RankCursor::Live(last_applied_id) => {
if event_id <= last_applied_id {
// We've already applied this event. Nothing to do.
return;
} else if rank_state.recovery_inflight {
// A recovery is already in flight. Drop the event for now, and potentially spawn a new recovery afterwards.
rank_state.observe_live_id(event_id);
return;
} else if event_id > last_applied_id.saturating_add(1) {
// We've detected a gap. Spawn a new recovery pass.
rank_state.observe_live_id(event_id);
rank_state.recovery_inflight = true;
Action::SpawnIncremental {
epoch: worker_state.epoch,
start_event_id: last_applied_id.saturating_add(1),
}
} else {
// Apply the event.
rank_state.cursor = RankCursor::Live(event_id);
rank_state.clear_max_seen_if_caught_up(event_id);
Action::ApplyDirect
}
}
// A worker-wide barrier (currently `Cleared`) invalidated this rank's old
// cursor. The next newer live event becomes the new starting point; we do not
// recover across the barrier.
RankCursor::InvalidatedByBarrier(last_applied_id) => {
if last_applied_id
.is_some_and(|last_applied_id| event_id <= last_applied_id)
{
return;
} else {
rank_state.cursor = RankCursor::Live(event_id);
rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false;
Action::ApplyDirect
}
}
}
}
};
Ok(self
.routers
.entry(dp_rank)
.or_insert(router)
.value()
.clone())
match action {
Action::ApplyDirect => {
self.indexer.apply_event(event).await;
}
Action::SpawnFullRestore { epoch } => {
self.spawn_recovery_task(key, epoch, None, None);
}
Action::SpawnIncremental {
epoch,
start_event_id,
} => {
self.spawn_recovery_task(key, epoch, Some(start_event_id), None);
}
}
}
/// Query a specific worker's local KV indexer for a specific dp_rank.
pub async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
fn spawn_recovery_task(
self: &Arc<Self>,
key: RecoveryKey,
epoch: u64,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
let router = self.get_router_for_dp_rank(dp_rank).await?;
) {
let client = self.clone();
let request = WorkerKvQueryRequest {
worker_id,
start_event_id,
end_event_id,
tokio::spawn(async move {
let Ok(_permit) = client.recovery_semaphore.clone().acquire_owned().await else {
return;
};
let result = client
.fetch_recovery_response(key.0, key.1, start_event_id, end_event_id)
.await;
client.finish_recovery_task(key, epoch, result).await;
});
}
async fn finish_recovery_task(
self: Arc<Self>,
key: RecoveryKey,
epoch: u64,
result: Result<WorkerKvQueryResponse>,
) {
let Some(worker_state) = self.worker_states.get(&key.0).map(|entry| entry.clone()) else {
return;
};
let mut stream = router
.direct(SingleIn::new(request), worker_id)
.await
.with_context(|| {
format!("Failed to send worker KV query to worker {worker_id} dp_rank {dp_rank}")
})?;
let mut worker_state = worker_state.lock().await;
if worker_state.epoch != epoch {
tracing::debug!(
"Discarding stale recovery result for worker {} dp_rank {} due to epoch change",
key.0,
key.1
);
return;
}
let response = stream
.next()
.await
.context("Worker KV query returned an empty response stream")?;
let Some(rank_state) = worker_state.ranks.get(&key.1) else {
return;
};
if let Some(err) = response.err() {
return Err(err).context("Worker KV query response error");
let mut new_cursor = rank_state.cursor;
let mut successful_response = false;
let mut saw_clear = false;
match result {
Ok(WorkerKvQueryResponse::Events(events)) => {
tracing::debug!(
"Got {count} buffered events from worker {} dp_rank {}",
key.0,
key.1,
count = events.len()
);
for event in events {
let event_id = event.event.event_id;
if matches!(&event.event.data, KvCacheEventData::Cleared) {
self.apply_worker_clear_locked(&mut worker_state, event)
.await;
new_cursor = RankCursor::Live(event_id);
saw_clear = true;
continue;
}
self.indexer.apply_event(event).await;
new_cursor = RankCursor::Live(event_id);
}
successful_response = true;
}
Ok(WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
}) => {
tracing::info!(
"Got tree dump from worker {} dp_rank {} (range too old or unspecified), count: {}, last_event_id: {}",
key.0,
key.1,
events.len(),
last_event_id
);
for event in &events {
self.indexer.apply_event(event.clone()).await;
}
new_cursor = RankCursor::Live(last_event_id);
successful_response = true;
}
Ok(WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
}) => {
tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than available (newest: {newest_available}) for worker {} dp_rank {}",
key.0,
key.1
);
}
Ok(WorkerKvQueryResponse::InvalidRange { start_id, end_id }) => {
tracing::error!(
"Invalid range for worker {} dp_rank {}: end_id ({end_id}) < start_id ({start_id})",
key.0,
key.1
);
}
Ok(WorkerKvQueryResponse::Error(message)) => {
tracing::error!(
"Worker {} dp_rank {} query error: {}",
key.0,
key.1,
message
);
}
Err(error) => {
tracing::warn!(
"Failed recovery from worker {} dp_rank {}: {}",
key.0,
key.1,
error
);
}
}
Ok(response)
let mut follow_up_start = None;
{
let rank_state = worker_state
.ranks
.get_mut(&key.1)
.expect("rank state should exist while finishing recovery");
rank_state.recovery_inflight = false;
rank_state.cursor = new_cursor;
let last_applied_id = rank_state.last_applied_id().unwrap_or(0);
rank_state.clear_max_seen_if_caught_up(last_applied_id);
if successful_response
&& !saw_clear
&& rank_state
.max_seen_live_id
.is_some_and(|max_seen| max_seen > last_applied_id)
{
rank_state.recovery_inflight = true;
follow_up_start = Some(last_applied_id.saturating_add(1));
}
}
let follow_up_epoch = worker_state.epoch;
drop(worker_state);
if let Some(start_event_id) = follow_up_start {
self.spawn_recovery_task(key, follow_up_epoch, Some(start_event_id), None);
}
}
/// Query a worker's local KV indexer with exponential backoff retry.
async fn query_worker_with_retry(
async fn fetch_recovery_response(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
tracing::debug!(
"Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
let mut last_error = None;
for attempt in 0..RECOVERY_MAX_RETRIES {
match self
.transport
.query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await
{
......@@ -275,79 +667,6 @@ impl WorkerQueryClient {
Err(last_error
.unwrap_or_else(|| anyhow::anyhow!("No response after {RECOVERY_MAX_RETRIES} retries")))
}
/// Recover missed KV events from a specific worker's dp_rank with retry logic.
///
/// Called both by the internal discovery loop (initial recovery) and by the
/// event plane task in subscriber.rs (gap recovery).
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<usize> {
tracing::debug!(
"Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
let response = self
.query_worker_with_retry(worker_id, dp_rank, start_event_id, end_event_id)
.await?;
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(
"Got {count} buffered events from worker {worker_id} dp_rank {dp_rank}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
"Got tree dump from worker {worker_id} dp_rank {dp_rank} \
(range too old or unspecified), count: {count}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
} => {
tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than \
available (newest: {newest_available}) for worker {worker_id} dp_rank {dp_rank}"
);
return Ok(0);
}
WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
anyhow::bail!(
"Invalid range for worker {worker_id} dp_rank {dp_rank}: \
end_id ({end_id}) < start_id ({start_id})"
);
}
WorkerKvQueryResponse::Error(msg) => {
anyhow::bail!("Worker {worker_id} dp_rank {dp_rank} query error: {msg}");
}
};
let count = events.len();
if count == 0 {
tracing::debug!("No events to recover from worker {worker_id} dp_rank {dp_rank}");
return Ok(0);
}
tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");
for event in events {
self.indexer.apply_event(event).await;
}
Ok(count)
}
}
// ============================================================================
......@@ -443,11 +762,190 @@ impl AsyncEngine<SingleIn<WorkerKvQueryRequest>, ManyOut<WorkerKvQueryResponse>,
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::Indexer;
use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::KvIndexerMetrics;
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
use std::collections::VecDeque;
use std::sync::Mutex as StdMutex;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
#[derive(Clone)]
struct MockQueryAction {
started: Option<Arc<Notify>>,
release: Option<Arc<Notify>>,
response: Result<WorkerKvQueryResponse, String>,
}
#[derive(Default)]
struct MockWorkerQueryTransport {
actions: DashMap<RecoveryKey, Arc<StdMutex<VecDeque<MockQueryAction>>>>,
#[allow(clippy::type_complexity)]
calls: Arc<StdMutex<Vec<(RecoveryKey, Option<u64>, Option<u64>)>>>,
}
impl MockWorkerQueryTransport {
fn push_action(&self, key: RecoveryKey, action: MockQueryAction) {
let queue = self
.actions
.entry(key)
.or_insert_with(|| Arc::new(StdMutex::new(VecDeque::new())))
.clone();
queue.lock().unwrap().push_back(action);
}
fn call_count(&self) -> usize {
self.calls.lock().unwrap().len()
}
}
#[async_trait]
impl WorkerQueryTransport for MockWorkerQueryTransport {
async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
let key = (worker_id, dp_rank);
self.calls
.lock()
.unwrap()
.push((key, start_event_id, end_event_id));
let queue = self
.actions
.get(&key)
.unwrap_or_else(|| {
panic!("Missing action queue for worker {worker_id} dp_rank {dp_rank}")
})
.clone();
let action = queue.lock().unwrap().pop_front().unwrap_or_else(|| {
panic!("Missing action for worker {worker_id} dp_rank {dp_rank}")
});
if let Some(started) = action.started {
started.notify_waiters();
}
if let Some(release) = action.release {
release.notified().await;
}
match action.response {
Ok(response) => Ok(response),
Err(message) => Err(anyhow::anyhow!(message)),
}
}
}
async fn make_test_component(name: &str) -> Component {
let runtime = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(runtime, DistributedConfig::process_local())
.await
.unwrap();
let namespace = drt.namespace(format!("test-ns-{name}")).unwrap();
namespace
.component(format!("test-component-{name}"))
.unwrap()
}
fn make_test_indexer() -> (KvIndexer, Indexer) {
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let kv_indexer = KvIndexer::new(token, 4, metrics);
(kv_indexer.clone(), Indexer::KvIndexer(kv_indexer))
}
async fn make_test_client(
name: &str,
) -> (
Arc<WorkerQueryClient>,
Arc<MockWorkerQueryTransport>,
KvIndexer,
) {
let component = make_test_component(name).await;
let (kv_indexer, indexer) = make_test_indexer();
let transport = Arc::new(MockWorkerQueryTransport::default());
let client = WorkerQueryClient::new(component, indexer, transport.clone());
(client, transport, kv_indexer)
}
fn make_store_event(worker_id: WorkerId, dp_rank: DpRank, event_id: u64) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(event_id),
tokens_hash: LocalBlockHash(event_id),
mm_extra_info: None,
}],
}),
dp_rank,
},
)
}
fn make_clear_event(worker_id: WorkerId, dp_rank: DpRank, event_id: u64) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
},
)
}
fn stored_block_hashes(events: &[RouterEvent]) -> Vec<u64> {
let mut hashes = events
.iter()
.filter_map(|event| match &event.event.data {
KvCacheEventData::Stored(data) => {
data.blocks.first().map(|block| block.block_hash.0)
}
_ => None,
})
.collect::<Vec<_>>();
hashes.sort_unstable();
hashes
}
async fn wait_for<F>(mut check: F)
where
F: FnMut() -> bool,
{
for _ in 0..100 {
if check() {
return;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
panic!("condition not met before timeout");
}
fn rank_state_matches<F>(client: &Arc<WorkerQueryClient>, key: RecoveryKey, check: F) -> bool
where
F: FnOnce(&RankState) -> bool,
{
client
.worker_states
.get(&key.0)
.map(|worker_state| match worker_state.try_lock() {
Ok(worker_state) => worker_state.ranks.get(&key.1).is_some_and(check),
Err(_) => false,
})
.unwrap_or(false)
}
#[tokio::test]
async fn test_worker_kv_query_engine_returns_buffered_events() {
let worker_id = 7u64;
......@@ -497,4 +995,345 @@ mod tests {
other => panic!("Unexpected response: {other:?}"),
}
}
#[tokio::test]
async fn test_discovery_restore_does_not_block_other_workers() {
let (client, transport, kv_indexer) = make_test_client("discovery-concurrency").await;
let first_started = Arc::new(Notify::new());
let first_release = Arc::new(Notify::new());
transport.push_action(
(1, 0),
MockQueryAction {
started: Some(first_started.clone()),
release: Some(first_release.clone()),
response: Ok(WorkerKvQueryResponse::TreeDump {
events: vec![],
last_event_id: 0,
}),
},
);
transport.push_action(
(2, 0),
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::TreeDump {
events: vec![],
last_event_id: 0,
}),
},
);
client.handle_discovered_worker(1, 0).await;
first_started.notified().await;
client.handle_discovered_worker(2, 0).await;
wait_for(|| transport.call_count() == 2).await;
first_release.notify_waiters();
kv_indexer.flush().await;
}
#[tokio::test]
async fn test_gap_recovery_follows_high_water_mark() {
let (client, transport, kv_indexer) = make_test_client("high-water").await;
let key = (1, 0);
{
let worker_state = client.get_or_create_worker_state(key.0);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(key.1).or_default().cursor = RankCursor::Live(10);
}
let first_started = Arc::new(Notify::new());
let first_release = Arc::new(Notify::new());
transport.push_action(
key,
MockQueryAction {
started: Some(first_started.clone()),
release: Some(first_release.clone()),
response: Ok(WorkerKvQueryResponse::Events(
(11..=15).map(|id| make_store_event(1, 0, id)).collect(),
)),
},
);
transport.push_action(
key,
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::Events(
(16..=18).map(|id| make_store_event(1, 0, id)).collect(),
)),
},
);
client.handle_live_event(make_store_event(1, 0, 15)).await;
first_started.notified().await;
client.handle_live_event(make_store_event(1, 0, 16)).await;
client.handle_live_event(make_store_event(1, 0, 17)).await;
client.handle_live_event(make_store_event(1, 0, 18)).await;
first_release.notify_waiters();
wait_for(|| {
rank_state_matches(&client, key, |state| {
state.last_applied_id() == Some(18) && !state.recovery_inflight
})
})
.await;
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert_eq!(
stored_block_hashes(&events),
vec![11, 12, 13, 14, 15, 16, 17, 18]
);
}
#[tokio::test]
async fn test_initial_restore_updates_cursor_for_live_and_gap_paths() {
let (client, transport, kv_indexer) = make_test_client("initial-restore-cursor").await;
let key = (1, 0);
transport.push_action(
key,
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::TreeDump {
events: vec![],
last_event_id: 10,
}),
},
);
transport.push_action(
key,
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::Events(vec![
make_store_event(1, 0, 12),
make_store_event(1, 0, 13),
])),
},
);
client.handle_discovered_worker(1, 0).await;
wait_for(|| {
rank_state_matches(&client, key, |state| {
state.last_applied_id() == Some(10) && !state.recovery_inflight
})
})
.await;
assert_eq!(transport.call_count(), 1);
client.handle_live_event(make_store_event(1, 0, 11)).await;
wait_for(|| {
rank_state_matches(&client, key, |state| {
state.last_applied_id() == Some(11) && !state.recovery_inflight
})
})
.await;
assert_eq!(transport.call_count(), 1);
client.handle_live_event(make_store_event(1, 0, 13)).await;
wait_for(|| {
rank_state_matches(&client, key, |state| {
state.last_applied_id() == Some(13) && !state.recovery_inflight
})
})
.await;
assert_eq!(transport.call_count(), 2);
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert_eq!(stored_block_hashes(&events), vec![11, 12, 13]);
}
#[tokio::test]
async fn test_live_event_for_other_worker_is_not_blocked_by_inflight_recovery() {
let (client, transport, kv_indexer) = make_test_client("live-concurrency").await;
let delayed_key = (1, 0);
{
let worker_state = client.get_or_create_worker_state(delayed_key.0);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(delayed_key.1).or_default().cursor = RankCursor::Live(10);
}
let other_key = (2, 0);
{
let worker_state = client.get_or_create_worker_state(other_key.0);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(other_key.1).or_default().cursor = RankCursor::Live(20);
}
let started = Arc::new(Notify::new());
let release = Arc::new(Notify::new());
transport.push_action(
delayed_key,
MockQueryAction {
started: Some(started.clone()),
release: Some(release.clone()),
response: Ok(WorkerKvQueryResponse::Events(vec![
make_store_event(1, 0, 11),
make_store_event(1, 0, 12),
make_store_event(1, 0, 13),
])),
},
);
client.handle_live_event(make_store_event(1, 0, 13)).await;
started.notified().await;
client.handle_live_event(make_store_event(2, 0, 21)).await;
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert!(events.iter().any(|event| {
event.worker_id == 2
&& event.event.dp_rank == 0
&& matches!(
&event.event.data,
KvCacheEventData::Stored(data)
if data.blocks.first().map(|block| block.block_hash.0) == Some(21)
)
}));
release.notify_waiters();
}
#[tokio::test]
async fn test_worker_removal_discards_late_recovery_result() {
let (client, transport, kv_indexer) = make_test_client("remove-race").await;
let key = (1, 0);
{
let worker_state = client.get_or_create_worker_state(key.0);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(key.1).or_default().cursor = RankCursor::Live(10);
}
let started = Arc::new(Notify::new());
let release = Arc::new(Notify::new());
transport.push_action(
key,
MockQueryAction {
started: Some(started.clone()),
release: Some(release.clone()),
response: Ok(WorkerKvQueryResponse::Events(vec![
make_store_event(1, 0, 11),
make_store_event(1, 0, 12),
])),
},
);
client.handle_live_event(make_store_event(1, 0, 12)).await;
started.notified().await;
client.handle_removed_worker_dp(1, 0).await;
release.notify_waiters();
wait_for(|| !rank_state_matches(&client, key, |_| true)).await;
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert!(events.is_empty());
}
#[tokio::test]
async fn test_live_cleared_invalidates_inflight_recovery_without_restore() {
let (client, transport, kv_indexer) = make_test_client("live-cleared-no-restore").await;
let key0 = (1, 0);
let key1 = (1, 1);
{
let worker_state = client.get_or_create_worker_state(1);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(0).or_default().cursor = RankCursor::Live(10);
worker_state.ranks.entry(1).or_default().cursor = RankCursor::Live(20);
}
let started = Arc::new(Notify::new());
let release = Arc::new(Notify::new());
transport.push_action(
key0,
MockQueryAction {
started: Some(started.clone()),
release: Some(release.clone()),
response: Ok(WorkerKvQueryResponse::Events(vec![
make_store_event(1, 0, 11),
make_store_event(1, 0, 12),
make_store_event(1, 0, 13),
])),
},
);
client.handle_live_event(make_store_event(1, 0, 13)).await;
started.notified().await;
client.handle_live_event(make_clear_event(1, 0, 14)).await;
wait_for(|| transport.call_count() == 1).await;
release.notify_waiters();
wait_for(|| {
rank_state_matches(&client, key0, |state| {
state.last_applied_id() == Some(14) && !state.recovery_inflight
}) && rank_state_matches(&client, key1, |state| {
state.last_applied_id() == Some(20) && !state.recovery_inflight
})
})
.await;
client.handle_live_event(make_store_event(1, 0, 15)).await;
client.handle_live_event(make_store_event(1, 1, 30)).await;
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert_eq!(transport.call_count(), 1);
assert_eq!(stored_block_hashes(&events), vec![15, 30]);
}
#[tokio::test]
async fn test_recovered_cleared_resumes_live_without_restore() {
let (client, transport, kv_indexer) =
make_test_client("recovered-cleared-no-restore").await;
let key0 = (1, 0);
let key1 = (1, 1);
{
let worker_state = client.get_or_create_worker_state(1);
let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(0).or_default().cursor = RankCursor::Live(10);
worker_state.ranks.entry(1).or_default().cursor = RankCursor::Live(20);
}
transport.push_action(
key0,
MockQueryAction {
started: None,
release: None,
response: Ok(WorkerKvQueryResponse::Events(vec![
make_store_event(1, 0, 11),
make_clear_event(1, 0, 12),
make_store_event(1, 0, 13),
])),
},
);
client.handle_live_event(make_store_event(1, 0, 13)).await;
wait_for(|| {
rank_state_matches(&client, key0, |state| {
state.last_applied_id() == Some(13) && !state.recovery_inflight
}) && rank_state_matches(&client, key1, |state| {
state.last_applied_id() == Some(20) && !state.recovery_inflight
})
})
.await;
assert_eq!(transport.call_count(), 1);
client.handle_live_event(make_store_event(1, 0, 14)).await;
client.handle_live_event(make_store_event(1, 1, 30)).await;
kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap();
assert_eq!(stored_block_hashes(&events), vec![13, 14, 30]);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment