"docs/features/vscode:/vscode.git/clone" did not exist on "435c8024945a9cef91226c891eec5694f556739e"
Unverified Commit 7389a369 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(kv-router): share recovery cursor state (#7596)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent a818a4bd
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
pub mod recovery;
pub mod scheduling; pub mod scheduling;
pub mod sequences; pub mod sequences;
pub mod zmq_wire; pub mod zmq_wire;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Shared cursor state for monotonically increasing event streams.
///
/// `InvalidatedByBarrier` represents a semantic stream boundary such as a
/// worker-wide `Cleared` event. After such a barrier, callers must not attempt
/// to recover pre-barrier gaps.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CursorState {
#[default]
Initial,
Live(u64),
InvalidatedByBarrier(Option<u64>),
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CursorObservation {
Initial {
got: u64,
},
Contiguous {
got: u64,
},
Gap {
expected: u64,
got: u64,
},
Stale {
got: u64,
last_applied: Option<u64>,
},
FreshAfterBarrier {
got: u64,
last_before_barrier: Option<u64>,
},
}
impl CursorState {
#[must_use]
pub fn last_applied_id(self) -> Option<u64> {
match self {
CursorState::Initial => None,
CursorState::Live(id) => Some(id),
CursorState::InvalidatedByBarrier(last_applied) => last_applied,
}
}
#[must_use]
pub fn observe(self, got: u64) -> CursorObservation {
match self {
CursorState::Initial => CursorObservation::Initial { got },
CursorState::Live(last) if got <= last => CursorObservation::Stale {
got,
last_applied: Some(last),
},
CursorState::Live(last) if got == last + 1 => CursorObservation::Contiguous { got },
CursorState::Live(last) => CursorObservation::Gap {
expected: last + 1,
got,
},
CursorState::InvalidatedByBarrier(last_before_barrier)
if last_before_barrier.is_some_and(|last| got <= last) =>
{
CursorObservation::Stale {
got,
last_applied: last_before_barrier,
}
}
CursorState::InvalidatedByBarrier(last_before_barrier) => {
CursorObservation::FreshAfterBarrier {
got,
last_before_barrier,
}
}
}
}
#[must_use]
pub fn advance_to(self, id: u64) -> Self {
let _ = self;
CursorState::Live(id)
}
#[must_use]
pub fn invalidate_by_barrier(self) -> Self {
CursorState::InvalidatedByBarrier(self.last_applied_id())
}
#[must_use]
pub fn apply_barrier(self, clear_id: u64) -> Self {
let _ = self;
CursorState::Live(clear_id)
}
}
#[cfg(test)]
mod tests {
use super::{CursorObservation, CursorState};
#[test]
fn initial_observation_preserves_first_id() {
assert_eq!(
CursorState::Initial.observe(0),
CursorObservation::Initial { got: 0 }
);
assert_eq!(
CursorState::Initial.observe(5),
CursorObservation::Initial { got: 5 }
);
}
#[test]
fn live_observation_detects_contiguous_gap_and_stale_ids() {
assert_eq!(
CursorState::Live(10).observe(11),
CursorObservation::Contiguous { got: 11 }
);
assert_eq!(
CursorState::Live(10).observe(15),
CursorObservation::Gap {
expected: 11,
got: 15,
}
);
assert_eq!(
CursorState::Live(10).observe(10),
CursorObservation::Stale {
got: 10,
last_applied: Some(10),
}
);
assert_eq!(
CursorState::Live(10).observe(9),
CursorObservation::Stale {
got: 9,
last_applied: Some(10),
}
);
}
#[test]
fn barrier_invalidation_preserves_last_applied_id() {
assert_eq!(
CursorState::Live(17).invalidate_by_barrier(),
CursorState::InvalidatedByBarrier(Some(17))
);
assert_eq!(
CursorState::InvalidatedByBarrier(Some(17)).observe(16),
CursorObservation::Stale {
got: 16,
last_applied: Some(17),
}
);
assert_eq!(
CursorState::InvalidatedByBarrier(Some(17)).observe(20),
CursorObservation::FreshAfterBarrier {
got: 20,
last_before_barrier: Some(17),
}
);
}
#[test]
fn apply_barrier_and_advance_restore_live_cursor() {
assert_eq!(
CursorState::Initial.apply_barrier(20),
CursorState::Live(20)
);
assert_eq!(CursorState::Initial.advance_to(7), CursorState::Live(7));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod cursor;
pub use cursor::{CursorObservation, CursorState};
...@@ -12,6 +12,7 @@ use tokio_util::sync::CancellationToken; ...@@ -12,6 +12,7 @@ use tokio_util::sync::CancellationToken;
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket}; use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket};
use crate::protocols::{WorkerId, WorkerWithDpRank}; use crate::protocols::{WorkerId, WorkerWithDpRank};
use crate::recovery::{CursorObservation, CursorState};
use crate::zmq_wire::{KvEventBatch, convert_event}; use crate::zmq_wire::{KvEventBatch, convert_event};
use super::indexer::Indexer; use super::indexer::Indexer;
...@@ -31,41 +32,92 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { ...@@ -31,41 +32,92 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
const WATERMARK_UNSET: u64 = u64::MAX; const WATERMARK_UNSET: u64 = u64::MAX;
fn gap_start(prev: u64, seq: u64) -> Option<u64> { fn cursor_from_watermark(watermark: u64) -> CursorState {
if prev == WATERMARK_UNSET { if watermark == WATERMARK_UNSET {
return (seq > 0).then_some(0); CursorState::Initial
} else {
CursorState::Live(watermark)
} }
}
(seq > prev + 1).then_some(prev + 1) struct ListenerLoop {
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
socket: SubSocket,
replay_socket: Option<DealerSocket>,
watermark: Arc<AtomicU64>,
warning_count: Arc<AtomicU32>,
consecutive_errors: u32,
messages_processed: u64,
} }
#[expect(clippy::too_many_arguments)] impl ListenerLoop {
async fn replay_gap( #[expect(clippy::too_many_arguments)]
replay_socket: &mut DealerSocket, fn new(
start_seq: u64,
end_seq: u64,
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: u32, dp_rank: u32,
block_size: u32, block_size: u32,
indexer: &Indexer, indexer: Indexer,
warning_count: &Arc<AtomicU32>, cancel: CancellationToken,
watermark: &Arc<AtomicU64>, socket: SubSocket,
) -> u64 { replay_socket: Option<DealerSocket>,
tracing::info!( watermark: Arc<AtomicU64>,
) -> Self {
Self {
worker_id, worker_id,
dp_rank, dp_rank,
block_size,
indexer,
cancel,
socket,
replay_socket,
watermark,
warning_count: Arc::new(AtomicU32::new(0)),
consecutive_errors: 0,
messages_processed: 0,
}
}
fn cursor(&self) -> CursorState {
cursor_from_watermark(self.watermark.load(Ordering::Acquire))
}
async fn replay_gap(&mut self, start_seq: u64, end_seq: u64) -> u64 {
tracing::info!(
self.worker_id,
self.dp_rank,
start_seq, start_seq,
end_seq, end_seq,
"Requesting replay from engine" "Requesting replay from engine"
); );
let Some(replay_socket) = self.replay_socket.as_mut() else {
tracing::warn!(
self.worker_id,
self.dp_rank,
gap_size = end_seq.saturating_sub(start_seq),
"No replay endpoint configured; batches lost"
);
return 0;
};
let worker_id = self.worker_id;
let dp_rank = self.dp_rank;
let block_size = self.block_size;
let indexer = &self.indexer;
let warning_count = &self.warning_count;
let watermark = &self.watermark;
let req_frames = vec![Bytes::new(), Bytes::from(start_seq.to_be_bytes().to_vec())]; let req_frames = vec![Bytes::new(), Bytes::from(start_seq.to_be_bytes().to_vec())];
let Ok(req_msg) = zeromq::ZmqMessage::try_from(req_frames) else { let Ok(req_msg) = zeromq::ZmqMessage::try_from(req_frames) else {
tracing::error!(worker_id, dp_rank, "Failed to build replay request"); tracing::error!(worker_id, dp_rank, "Failed to build replay request");
return 0; return 0;
}; };
if let Err(e) = replay_socket.send(req_msg).await { if let Err(error) = replay_socket.send(req_msg).await {
tracing::error!(worker_id, dp_rank, error = %e, "Failed to send replay request"); tracing::error!(worker_id, dp_rank, error = %error, "Failed to send replay request");
return 0; return 0;
} }
...@@ -132,6 +184,157 @@ async fn replay_gap( ...@@ -132,6 +184,157 @@ async fn replay_gap(
tracing::info!(worker_id, dp_rank, replayed, "Replay complete"); tracing::info!(worker_id, dp_rank, replayed, "Replay complete");
replayed replayed
}
async fn handle_gap(&mut self, seq: u64) {
match self.cursor().observe(seq) {
CursorObservation::Initial { got } if got > 0 => {
tracing::warn!(
self.worker_id,
self.dp_rank,
expected = 0,
got,
"Gap detected: expected seq 0, got {got}"
);
self.replay_gap(0, got).await;
}
CursorObservation::Gap { expected, got } => {
tracing::warn!(
self.worker_id,
self.dp_rank,
expected,
got,
"Gap detected: expected seq {expected}, got {got}"
);
self.replay_gap(expected, got).await;
}
CursorObservation::Initial { .. }
| CursorObservation::Contiguous { .. }
| CursorObservation::Stale { .. }
| CursorObservation::FreshAfterBarrier { .. } => {}
}
}
async fn apply_live_batch(&mut self, seq: u64, payload: &[u8]) {
let batch = match rmps::from_slice::<KvEventBatch>(payload) {
Ok(batch) => batch,
Err(error) => {
tracing::warn!(
self.worker_id,
self.dp_rank,
"Failed to decode KvEventBatch: {error}"
);
return;
}
};
let effective_dp_rank = batch
.data_parallel_rank
.map_or(self.dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let placement_event = convert_event(
raw_event,
seq,
self.block_size,
WorkerWithDpRank::new(self.worker_id, effective_dp_rank),
&self.warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
self.indexer.apply_event(router_event).await;
self.messages_processed += 1;
}
self.watermark.store(seq, Ordering::Release);
}
async fn handle_message(&mut self, msg: zeromq::ZmqMessage) {
if msg.len() != 3 {
tracing::warn!(
self.worker_id,
self.dp_rank,
"Unexpected ZMQ frame count: {}",
msg.len()
);
return;
}
let seq_bytes = msg.get(1).expect("frame count checked above");
if seq_bytes.len() != 8 {
tracing::warn!(
self.worker_id,
self.dp_rank,
"Invalid sequence number length: {}",
seq_bytes.len()
);
return;
}
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().expect("length checked above"));
self.handle_gap(seq).await;
if matches!(self.cursor().observe(seq), CursorObservation::Stale { .. }) {
return;
}
let payload = msg.get(2).expect("frame count checked above");
self.apply_live_batch(seq, payload).await;
}
async fn run(mut self) -> Result<(), String> {
loop {
let msg = tokio::select! {
biased;
_ = self.cancel.cancelled() => {
tracing::info!(
self.worker_id,
self.dp_rank,
self.messages_processed,
"ZMQ listener exiting after cancellation"
);
return Ok(());
}
msg_result = self.socket.recv() => {
match msg_result {
Ok(msg) => {
self.consecutive_errors = 0;
msg
}
Err(error) => {
self.consecutive_errors += 1;
if self.consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Err(format!(
"too many consecutive ZMQ recv errors for worker {} dp_rank {}: {error}",
self.worker_id,
self.dp_rank,
));
}
let backoff_ms = calculate_backoff_ms(self.consecutive_errors);
tracing::warn!(
error = %error,
consecutive_errors = self.consecutive_errors,
backoff_ms,
worker_id = self.worker_id,
dp_rank = self.dp_rank,
"ZMQ recv error, backing off"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
}
}
};
self.handle_message(msg).await;
}
}
} }
pub fn spawn_zmq_listener( pub fn spawn_zmq_listener(
...@@ -216,7 +419,7 @@ async fn run_listener( ...@@ -216,7 +419,7 @@ async fn run_listener(
return Ok(()); return Ok(());
} }
zmq_recv_loop( ListenerLoop::new(
worker_id, worker_id,
dp_rank, dp_rank,
block_size, block_size,
...@@ -226,6 +429,7 @@ async fn run_listener( ...@@ -226,6 +429,7 @@ async fn run_listener(
replay_socket, replay_socket,
watermark, watermark,
) )
.run()
.await .await
} }
...@@ -265,157 +469,28 @@ async fn connect_replay_socket( ...@@ -265,157 +469,28 @@ async fn connect_replay_socket(
} }
} }
#[expect(clippy::too_many_arguments)] #[cfg(test)]
async fn zmq_recv_loop( mod tests {
worker_id: WorkerId, use super::{WATERMARK_UNSET, cursor_from_watermark};
dp_rank: u32, use crate::recovery::CursorObservation;
block_size: u32, use zeromq::{PubSocket, Socket, SocketRecv, SocketSend, SubSocket};
indexer: Indexer,
cancel: CancellationToken,
mut socket: SubSocket,
mut replay_socket: Option<DealerSocket>,
watermark: Arc<AtomicU64>,
) -> Result<(), String> {
let warning_count = Arc::new(AtomicU32::new(0));
let mut consecutive_errors = 0u32;
let mut messages_processed = 0u64;
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::info!(
worker_id,
dp_rank,
messages_processed,
"ZMQ listener exiting after cancellation"
);
return Ok(());
}
msg_result = socket.recv() => {
let msg = match msg_result {
Ok(msg) => msg,
Err(e) => {
consecutive_errors += 1;
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Err(format!(
"too many consecutive ZMQ recv errors for worker {worker_id} dp_rank {dp_rank}: {e}"
));
}
let backoff_ms = calculate_backoff_ms(consecutive_errors);
tracing::warn!(
error = %e,
consecutive_errors,
backoff_ms,
worker_id,
dp_rank,
"ZMQ recv error, backing off"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
};
consecutive_errors = 0;
if msg.len() != 3 {
tracing::warn!(worker_id, dp_rank, "Unexpected ZMQ frame count: {}", msg.len());
continue;
}
let seq_bytes = msg.get(1).expect("frame count checked above");
if seq_bytes.len() != 8 {
tracing::warn!(
worker_id,
dp_rank,
"Invalid sequence number length: {}",
seq_bytes.len()
);
continue;
}
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().expect("length checked above"));
let prev = watermark.load(Ordering::Acquire);
if let Some(gap_start) = gap_start(prev, seq) {
tracing::warn!(
worker_id,
dp_rank,
expected = gap_start,
got = seq,
"Gap detected: expected seq {gap_start}, got {seq}"
);
match replay_socket.as_mut() {
Some(socket) => {
replay_gap(
socket,
gap_start,
seq,
worker_id,
dp_rank,
block_size,
&indexer,
&warning_count,
&watermark,
)
.await;
}
None => tracing::warn!(
worker_id,
dp_rank,
gap_size = seq - gap_start,
"No replay endpoint configured; batches lost"
),
}
}
let current_wm = watermark.load(Ordering::Acquire);
if current_wm != WATERMARK_UNSET && current_wm >= seq {
continue;
}
let payload = msg.get(2).expect("frame count checked above"); #[test]
let batch = match rmps::from_slice::<KvEventBatch>(payload) { fn initial_gap_replays_from_zero_and_replayed_seq_becomes_stale() {
Ok(batch) => batch, let replay_start = match cursor_from_watermark(WATERMARK_UNSET).observe(5) {
Err(error) => { CursorObservation::Initial { got } if got > 0 => Some(0),
tracing::warn!(worker_id, dp_rank, "Failed to decode KvEventBatch: {error}"); CursorObservation::Gap { expected, .. } => Some(expected),
continue; _ => None,
}
}; };
assert_eq!(replay_start, Some(0));
let effective_dp_rank = batch assert!(matches!(
.data_parallel_rank cursor_from_watermark(5).observe(5),
.map_or(dp_rank, |rank| rank.cast_unsigned()); CursorObservation::Stale {
for raw_event in batch.events { got: 5,
let placement_event = convert_event( last_applied: Some(5),
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
&warning_count,
);
if !placement_event.placement.is_local_gpu() {
continue;
}
let router_event = placement_event
.into_router_event()
.expect("local worker placement must convert to router event");
indexer.apply_event(router_event).await;
messages_processed += 1;
}
watermark.store(seq, Ordering::Release);
}
} }
));
} }
}
#[cfg(test)]
mod tests {
use zeromq::{PubSocket, Socket, SocketRecv, SocketSend, SubSocket};
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn zmq_buffers_messages_during_brief_delay() { async fn zmq_buffers_messages_during_brief_delay() {
......
...@@ -298,6 +298,7 @@ impl ListenerRecord { ...@@ -298,6 +298,7 @@ impl ListenerRecord {
} }
} }
#[allow(dead_code)]
fn status(&self) -> ListenerStatus { fn status(&self) -> ListenerStatus {
self.runtime.lock().status self.runtime.lock().status
} }
......
...@@ -25,6 +25,7 @@ use crate::kv_router::worker_kv_indexer_query_endpoint; ...@@ -25,6 +25,7 @@ use crate::kv_router::worker_kv_indexer_query_endpoint;
use dynamo_kv_router::{ use dynamo_kv_router::{
indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse}, indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse},
protocols::{DpRank, KvCacheEventData, RouterEvent, WorkerId}, protocols::{DpRank, KvCacheEventData, RouterEvent, WorkerId},
recovery::{CursorObservation, CursorState},
}; };
// Recovery retry configuration // Recovery retry configuration
...@@ -37,28 +38,16 @@ const QUERY_ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp"; ...@@ -37,28 +38,16 @@ const QUERY_ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp";
type RecoveryKey = (WorkerId, DpRank); type RecoveryKey = (WorkerId, DpRank);
#[derive(Clone, Copy, Debug, Default)]
enum RankCursor {
#[default]
NeedsRestore,
Live(u64),
InvalidatedByBarrier(Option<u64>),
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct RankState { struct RankState {
cursor: RankCursor, cursor: CursorState,
max_seen_live_id: Option<u64>, max_seen_live_id: Option<u64>,
recovery_inflight: bool, recovery_inflight: bool,
} }
impl RankState { impl RankState {
fn last_applied_id(&self) -> Option<u64> { fn last_applied_id(&self) -> Option<u64> {
match self.cursor { self.cursor.last_applied_id()
RankCursor::NeedsRestore => None,
RankCursor::Live(event_id) => Some(event_id),
RankCursor::InvalidatedByBarrier(last_applied_id) => last_applied_id,
}
} }
fn observe_live_id(&mut self, event_id: u64) { fn observe_live_id(&mut self, event_id: u64) {
...@@ -301,9 +290,7 @@ impl WorkerQueryClient { ...@@ -301,9 +290,7 @@ impl WorkerQueryClient {
let spawn = { let spawn = {
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
let rank_state = worker_state.ranks.entry(dp_rank).or_default(); let rank_state = worker_state.ranks.entry(dp_rank).or_default();
if matches!(rank_state.cursor, RankCursor::NeedsRestore) if matches!(rank_state.cursor, CursorState::Initial) && !rank_state.recovery_inflight {
&& !rank_state.recovery_inflight
{
tracing::info!( tracing::info!(
"WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, scheduling restore" "WorkerQueryClient: discovered worker {worker_id} dp_rank {dp_rank}, scheduling restore"
); );
...@@ -350,13 +337,13 @@ impl WorkerQueryClient { ...@@ -350,13 +337,13 @@ impl WorkerQueryClient {
worker_state.epoch += 1; worker_state.epoch += 1;
for rank_state in worker_state.ranks.values_mut() { for rank_state in worker_state.ranks.values_mut() {
rank_state.cursor = RankCursor::InvalidatedByBarrier(rank_state.last_applied_id()); rank_state.cursor = rank_state.cursor.invalidate_by_barrier();
rank_state.max_seen_live_id = None; rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false; rank_state.recovery_inflight = false;
} }
let rank_state = worker_state.ranks.entry(clear_dp_rank).or_default(); let rank_state = worker_state.ranks.entry(clear_dp_rank).or_default();
rank_state.cursor = RankCursor::Live(clear_event_id); rank_state.cursor = rank_state.cursor.apply_barrier(clear_event_id);
tracing::info!( tracing::info!(
"Applying clear barrier for worker {worker_id}; invalidating recovery across {} dp_ranks", "Applying clear barrier for worker {worker_id}; invalidating recovery across {} dp_ranks",
...@@ -394,64 +381,44 @@ impl WorkerQueryClient { ...@@ -394,64 +381,44 @@ impl WorkerQueryClient {
} }
// Already applied the event, so no further action needed. // Already applied the event, so no further action needed.
return; return;
} else { }
match rank_state.cursor {
// We have never established a cursor for this rank, so live traffic only tells match rank_state.cursor.observe(event_id) {
// us how far ahead the stream has moved while a full restore catches up. CursorObservation::Stale { .. } => return,
RankCursor::NeedsRestore => { observation if rank_state.recovery_inflight => {
match observation {
CursorObservation::Initial { .. }
| CursorObservation::Contiguous { .. }
| CursorObservation::Gap { .. }
| CursorObservation::FreshAfterBarrier { .. } => {
rank_state.observe_live_id(event_id);
}
CursorObservation::Stale { .. } => {}
}
return;
}
CursorObservation::Initial { .. } => {
rank_state.observe_live_id(event_id); rank_state.observe_live_id(event_id);
if !rank_state.recovery_inflight {
rank_state.recovery_inflight = true; rank_state.recovery_inflight = true;
Action::SpawnFullRestore { Action::SpawnFullRestore {
epoch: worker_state.epoch, epoch: worker_state.epoch,
} }
} else {
// A recovery is already in flight. Nothing to do.
return;
}
} }
// Normal steady-state path: apply contiguous events directly, but coalesce any CursorObservation::Gap { expected, .. } => {
// gap into a single recovery pass using `max_seen_live_id` as the high-water mark.
RankCursor::Live(last_applied_id) => {
if event_id <= last_applied_id {
// We've already applied this event. Nothing to do.
return;
} else if rank_state.recovery_inflight {
// A recovery is already in flight. Drop the event for now, and potentially spawn a new recovery afterwards.
rank_state.observe_live_id(event_id);
return;
} else if event_id > last_applied_id.saturating_add(1) {
// We've detected a gap. Spawn a new recovery pass.
rank_state.observe_live_id(event_id); rank_state.observe_live_id(event_id);
rank_state.recovery_inflight = true; rank_state.recovery_inflight = true;
Action::SpawnIncremental { Action::SpawnIncremental {
epoch: worker_state.epoch, epoch: worker_state.epoch,
start_event_id: last_applied_id.saturating_add(1), start_event_id: expected,
} }
} else {
// Apply the event.
rank_state.cursor = RankCursor::Live(event_id);
rank_state.clear_max_seen_if_caught_up(event_id);
Action::ApplyDirect
} }
} CursorObservation::Contiguous { got }
// A worker-wide barrier (currently `Cleared`) invalidated this rank's old | CursorObservation::FreshAfterBarrier { got, .. } => {
// cursor. The next newer live event becomes the new starting point; we do not rank_state.cursor = rank_state.cursor.advance_to(got);
// recover across the barrier. rank_state.clear_max_seen_if_caught_up(got);
RankCursor::InvalidatedByBarrier(last_applied_id) => {
if last_applied_id
.is_some_and(|last_applied_id| event_id <= last_applied_id)
{
return;
} else {
rank_state.cursor = RankCursor::Live(event_id);
rank_state.max_seen_live_id = None;
rank_state.recovery_inflight = false;
Action::ApplyDirect Action::ApplyDirect
} }
} }
}
}
}; };
match action { match action {
...@@ -531,12 +498,12 @@ impl WorkerQueryClient { ...@@ -531,12 +498,12 @@ impl WorkerQueryClient {
if matches!(&event.event.data, KvCacheEventData::Cleared) { if matches!(&event.event.data, KvCacheEventData::Cleared) {
self.apply_worker_clear_locked(&mut worker_state, event) self.apply_worker_clear_locked(&mut worker_state, event)
.await; .await;
new_cursor = RankCursor::Live(event_id); new_cursor = new_cursor.apply_barrier(event_id);
saw_clear = true; saw_clear = true;
continue; continue;
} }
self.indexer.apply_event(event).await; self.indexer.apply_event(event).await;
new_cursor = RankCursor::Live(event_id); new_cursor = new_cursor.advance_to(event_id);
} }
successful_response = true; successful_response = true;
} }
...@@ -554,16 +521,14 @@ impl WorkerQueryClient { ...@@ -554,16 +521,14 @@ impl WorkerQueryClient {
for event in &events { for event in &events {
self.indexer.apply_event(event.clone()).await; self.indexer.apply_event(event.clone()).await;
} }
new_cursor = RankCursor::Live(last_event_id); new_cursor = new_cursor.advance_to(last_event_id);
successful_response = true; successful_response = true;
} }
Ok(WorkerKvQueryResponse::TooNew { Ok(WorkerKvQueryResponse::TooNew {
requested_start, newest_available, ..
requested_end,
newest_available,
}) => { }) => {
tracing::warn!( tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than available (newest: {newest_available}) for worker {} dp_rank {}", "Requested recovery is newer than available (newest: {newest_available}) for worker {} dp_rank {}",
key.0, key.0,
key.1 key.1
); );
...@@ -803,6 +768,10 @@ mod tests { ...@@ -803,6 +768,10 @@ mod tests {
fn call_count(&self) -> usize { fn call_count(&self) -> usize {
self.calls.lock().unwrap().len() self.calls.lock().unwrap().len()
} }
fn calls(&self) -> Vec<(RecoveryKey, Option<u64>, Option<u64>)> {
self.calls.lock().unwrap().clone()
}
} }
#[async_trait] #[async_trait]
...@@ -1043,7 +1012,7 @@ mod tests { ...@@ -1043,7 +1012,7 @@ mod tests {
{ {
let worker_state = client.get_or_create_worker_state(key.0); let worker_state = client.get_or_create_worker_state(key.0);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(key.1).or_default().cursor = RankCursor::Live(10); worker_state.ranks.entry(key.1).or_default().cursor = CursorState::Live(10);
} }
let first_started = Arc::new(Notify::new()); let first_started = Arc::new(Notify::new());
...@@ -1082,6 +1051,10 @@ mod tests { ...@@ -1082,6 +1051,10 @@ mod tests {
}) })
}) })
.await; .await;
assert_eq!(
transport.calls(),
vec![(key, Some(11), None), (key, Some(16), None)]
);
kv_indexer.flush().await; kv_indexer.flush().await;
let events = kv_indexer.dump_events().await.unwrap(); let events = kv_indexer.dump_events().await.unwrap();
...@@ -1159,13 +1132,13 @@ mod tests { ...@@ -1159,13 +1132,13 @@ mod tests {
{ {
let worker_state = client.get_or_create_worker_state(delayed_key.0); let worker_state = client.get_or_create_worker_state(delayed_key.0);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(delayed_key.1).or_default().cursor = RankCursor::Live(10); worker_state.ranks.entry(delayed_key.1).or_default().cursor = CursorState::Live(10);
} }
let other_key = (2, 0); let other_key = (2, 0);
{ {
let worker_state = client.get_or_create_worker_state(other_key.0); let worker_state = client.get_or_create_worker_state(other_key.0);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(other_key.1).or_default().cursor = RankCursor::Live(20); worker_state.ranks.entry(other_key.1).or_default().cursor = CursorState::Live(20);
} }
let started = Arc::new(Notify::new()); let started = Arc::new(Notify::new());
...@@ -1210,7 +1183,7 @@ mod tests { ...@@ -1210,7 +1183,7 @@ mod tests {
{ {
let worker_state = client.get_or_create_worker_state(key.0); let worker_state = client.get_or_create_worker_state(key.0);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(key.1).or_default().cursor = RankCursor::Live(10); worker_state.ranks.entry(key.1).or_default().cursor = CursorState::Live(10);
} }
let started = Arc::new(Notify::new()); let started = Arc::new(Notify::new());
...@@ -1247,8 +1220,8 @@ mod tests { ...@@ -1247,8 +1220,8 @@ mod tests {
{ {
let worker_state = client.get_or_create_worker_state(1); let worker_state = client.get_or_create_worker_state(1);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(0).or_default().cursor = RankCursor::Live(10); worker_state.ranks.entry(0).or_default().cursor = CursorState::Live(10);
worker_state.ranks.entry(1).or_default().cursor = RankCursor::Live(20); worker_state.ranks.entry(1).or_default().cursor = CursorState::Live(20);
} }
let started = Arc::new(Notify::new()); let started = Arc::new(Notify::new());
...@@ -1280,6 +1253,9 @@ mod tests { ...@@ -1280,6 +1253,9 @@ mod tests {
}) })
}) })
.await; .await;
assert!(rank_state_matches(&client, key1, |state| {
matches!(state.cursor, CursorState::InvalidatedByBarrier(Some(20)))
}));
client.handle_live_event(make_store_event(1, 0, 15)).await; client.handle_live_event(make_store_event(1, 0, 15)).await;
client.handle_live_event(make_store_event(1, 1, 30)).await; client.handle_live_event(make_store_event(1, 1, 30)).await;
...@@ -1300,8 +1276,8 @@ mod tests { ...@@ -1300,8 +1276,8 @@ mod tests {
{ {
let worker_state = client.get_or_create_worker_state(1); let worker_state = client.get_or_create_worker_state(1);
let mut worker_state = worker_state.lock().await; let mut worker_state = worker_state.lock().await;
worker_state.ranks.entry(0).or_default().cursor = RankCursor::Live(10); worker_state.ranks.entry(0).or_default().cursor = CursorState::Live(10);
worker_state.ranks.entry(1).or_default().cursor = RankCursor::Live(20); worker_state.ranks.entry(1).or_default().cursor = CursorState::Live(20);
} }
transport.push_action( transport.push_action(
...@@ -1327,6 +1303,9 @@ mod tests { ...@@ -1327,6 +1303,9 @@ mod tests {
}) })
}) })
.await; .await;
assert!(rank_state_matches(&client, key1, |state| {
matches!(state.cursor, CursorState::InvalidatedByBarrier(Some(20)))
}));
assert_eq!(transport.call_count(), 1); assert_eq!(transport.call_count(), 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment