Unverified commit c7f6f6d9, authored by Janelle Cai and committed by GitHub
Browse files

fix: race between worker discovery and runtimeconfig discovery in KV router (#5924)


Signed-off-by: Janelle Cai <jcai18@mit.edu>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Yan Ru Pei <yanrpei@gmail.com>
parent 72a12869
...@@ -386,11 +386,12 @@ impl KvRouter { ...@@ -386,11 +386,12 @@ impl KvRouter {
.await?; .await?;
// Initialize worker query client using namespace abstraction // Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery) // (for query/recovery API methods - no lifecycle tracking needed)
// Uses a subscriber from workers_with_configs // Uses a subscriber from workers_with_configs
let worker_query_client = worker_query::WorkerQueryClient::new( let worker_query_client = worker_query::WorkerQueryClient::new(
component.clone(), component.clone(),
workers_with_configs.subscribe(), workers_with_configs.subscribe(),
None, // No removal channel - query only
); );
tracing::info!("Worker query client initialized"); tracing::info!("Worker query client initialized");
...@@ -431,11 +432,11 @@ impl KvRouter { ...@@ -431,11 +432,11 @@ impl KvRouter {
start_kv_router_background_event_plane( start_kv_router_background_event_plane(
component.clone(), component.clone(),
kv_indexer.event_sender(), kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
cancellation_token.clone(), cancellation_token.clone(),
worker_query::WorkerQueryClient::new( worker_query::WorkerQueryClient::new(
component.clone(), component.clone(),
workers_with_configs.subscribe(), workers_with_configs.subscribe(),
Some(kv_indexer.remove_worker_sender()),
), ),
transport_kind, transport_kind,
) )
......
...@@ -495,41 +495,6 @@ pub async fn start_kv_router_background( ...@@ -495,41 +495,6 @@ pub async fn start_kv_router_background(
Ok(()) Ok(())
} }
/// React to a single discovery event for a worker instance.
///
/// * `Added` - asks `worker_query_client` to recover every dp_rank of the new
///   worker, forwarding recovered events through `kv_events_tx`, and logs the
///   total when it is non-zero.
/// * `Removed` - forwards the worker id on `remove_worker_tx`.
///
/// A failed removal send is logged and otherwise ignored.
async fn handle_worker_discovery(
    event: DiscoveryEvent,
    worker_query_client: &WorkerQueryClient,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
    remove_worker_tx: &mpsc::Sender<WorkerId>,
) {
    // The `Added` arm finishes its work and returns; only a removal falls
    // through with the id that has to be reported.
    let worker_id = match event {
        DiscoveryEvent::Added(instance) => {
            let worker_id = instance.instance_id();
            tracing::info!(
                "DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
            );

            let total_recovered = worker_query_client
                .recover_all_dp_ranks(worker_id, kv_events_tx)
                .await;
            if total_recovered > 0 {
                tracing::info!(
                    "DISCOVERY: Worker {worker_id} total recovered {total_recovered} events"
                );
            }
            return;
        }
        DiscoveryEvent::Removed(id) => id.instance_id(),
    };

    tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
    if let Err(e) = remove_worker_tx.send(worker_id).await {
        tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
    }
}
/// Start a simplified background task for event consumption using the event plane. /// Start a simplified background task for event consumption using the event plane.
/// ///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`, /// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
...@@ -546,7 +511,6 @@ async fn handle_worker_discovery( ...@@ -546,7 +511,6 @@ async fn handle_worker_discovery(
pub async fn start_kv_router_background_event_plane( pub async fn start_kv_router_background_event_plane(
component: Component, component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>, kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
mut worker_query_client: WorkerQueryClient, mut worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind, transport_kind: EventTransportKind,
...@@ -587,18 +551,10 @@ pub async fn start_kv_router_background_event_plane( ...@@ -587,18 +551,10 @@ pub async fn start_kv_router_background_event_plane(
ready_workers.len() ready_workers.len()
); );
// Recover initial state from all ready workers (all dp_ranks) // Recover initial state from all workers with local indexer enabled
for worker_id in &ready_workers { worker_query_client
if worker_query_client.has_local_indexer(*worker_id) { .process_and_recover_workers(&kv_events_tx, "Initial recovery")
worker_query_client .await;
.recover_all_dp_ranks(*worker_id, &kv_events_tx)
.await;
}
}
// Get instance discovery stream for ongoing monitoring of worker add/remove events
let mut instance_event_stream =
get_instance_discovery_stream(&component, &cancellation_token).await?;
tokio::spawn(async move { tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection // Track last received event ID per (worker, dp_rank) for gap detection
...@@ -614,19 +570,16 @@ pub async fn start_kv_router_background_event_plane( ...@@ -614,19 +570,16 @@ pub async fn start_kv_router_background_event_plane(
break; break;
} }
// Handle generate endpoint instance add/remove events // Handle runtime config changes (worker add/remove, recovery for new workers)
Some(discovery_event_result) = instance_event_stream.next() => { result = worker_query_client.wait_for_config_change() => {
let Ok(event) = discovery_event_result else { if result.is_err() {
tracing::warn!("Runtime config watch sender dropped");
continue; continue;
}; }
handle_worker_discovery( worker_query_client
event, .process_and_recover_workers(&kv_events_tx, "DISCOVERY")
&worker_query_client, .await;
&kv_events_tx,
&remove_worker_tx,
)
.await;
} }
// Handle event consumption from event plane subscription // Handle event consumption from event plane subscription
...@@ -708,14 +661,12 @@ pub async fn start_kv_router_background_event_plane( ...@@ -708,14 +661,12 @@ pub async fn start_kv_router_background_event_plane(
pub async fn start_kv_router_background_nats_core( pub async fn start_kv_router_background_nats_core(
component: Component, component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>, kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient, worker_query_client: WorkerQueryClient,
) -> Result<()> { ) -> Result<()> {
start_kv_router_background_event_plane( start_kv_router_background_event_plane(
component, component,
kv_events_tx, kv_events_tx,
remove_worker_tx,
cancellation_token, cancellation_token,
worker_query_client, worker_query_client,
EventTransportKind::Nats, EventTransportKind::Nats,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
...@@ -33,21 +34,41 @@ const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200; ...@@ -33,21 +34,41 @@ const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
/// ///
/// Each dp_rank has its own LocalKvIndexer and query endpoint, so we maintain separate /// Each dp_rank has its own LocalKvIndexer and query endpoint, so we maintain separate
/// routers per dp_rank to ensure queries go to the correct endpoint. /// routers per dp_rank to ensure queries go to the correct endpoint.
///
/// Also handles worker lifecycle (add/remove) by tracking known workers and sending
/// removal events to the router indexer.
pub struct WorkerQueryClient { pub struct WorkerQueryClient {
component: Component, component: Component,
/// Subscriber for runtime configs (includes shared configs DashMap) /// Subscriber for runtime configs (includes shared configs DashMap)
subscriber: RuntimeConfigsSubscriber, subscriber: RuntimeConfigsSubscriber,
/// Routers keyed by dp_rank - each dp_rank has its own endpoint /// Routers keyed by dp_rank - each dp_rank has its own endpoint
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>, routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
/// Workers that have been successfully recovered (full recovery)
recovered: HashSet<WorkerId>,
/// Workers we know about (to detect removals)
known_workers: HashSet<WorkerId>,
/// Channel to send worker removal events to the router indexer (optional)
remove_worker_tx: Option<mpsc::Sender<WorkerId>>,
} }
impl WorkerQueryClient { impl WorkerQueryClient {
/// Create a new WorkerQueryClient with a subscriber to runtime configs /// Create a new WorkerQueryClient with a subscriber to runtime configs.
pub fn new(component: Component, subscriber: RuntimeConfigsSubscriber) -> Self { ///
/// If `remove_worker_tx` is provided, this client will handle worker lifecycle
/// (tracking known workers, sending removal events). If None, lifecycle tracking
/// is disabled (suitable for query-only usage).
pub fn new(
component: Component,
subscriber: RuntimeConfigsSubscriber,
remove_worker_tx: Option<mpsc::Sender<WorkerId>>,
) -> Self {
Self { Self {
component, component,
subscriber, subscriber,
routers: DashMap::new(), routers: DashMap::new(),
recovered: HashSet::new(),
known_workers: HashSet::new(),
remove_worker_tx,
} }
} }
...@@ -57,8 +78,84 @@ impl WorkerQueryClient { ...@@ -57,8 +78,84 @@ impl WorkerQueryClient {
self.subscriber.wait_for_some().await self.subscriber.wait_for_some().await
} }
/// Suspend until the runtime-config change receiver reports a new value.
///
/// Resolves to `Ok(())` once a change notification arrives, or to `Err` when
/// the sending half of the channel has been dropped.
pub async fn wait_for_config_change(
    &mut self,
) -> Result<(), tokio::sync::watch::error::RecvError> {
    let change_rx = &mut self.subscriber.change_rx;
    change_rx.changed().await
}
/// Process config changes and recover pending workers.
///
/// This method:
/// 1. Detects removed workers and sends removal events (if remove_worker_tx is set)
/// 2. Recovers workers that have config + local_indexer enabled but haven't been recovered yet
/// 3. Marks recovered workers so they won't be recovered again
///
/// Should be called after `wait_for_config_change()` returns.
///
/// # Arguments
/// * `event_tx` - Channel to send recovered events to the router indexer
/// * `log_prefix` - Prefix for log messages (e.g., "DISCOVERY" or "Initial recovery")
pub async fn process_and_recover_workers(
&mut self,
event_tx: &mpsc::Sender<RouterEvent>,
log_prefix: &str,
) {
// Get current workers from configs
let current_workers: HashSet<WorkerId> =
self.subscriber.configs.iter().map(|r| *r.key()).collect();
// Handle removed workers (only if we have a removal channel)
if let Some(ref remove_worker_tx) = self.remove_worker_tx {
for worker_id in self.known_workers.difference(&current_workers) {
self.recovered.remove(worker_id);
tracing::warn!(
"{log_prefix}: Worker {worker_id} removed, removing from router indexer"
);
if let Err(e) = remove_worker_tx.send(*worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
}
}
self.known_workers = current_workers;
// Find workers needing recovery:
// - Has config (Some)
// - Has local_indexer enabled
// - Not yet recovered
let workers_to_recover: Vec<WorkerId> = self
.subscriber
.configs
.iter()
.filter(|r| {
r.value()
.as_ref()
.map(|c| c.enable_local_indexer)
.unwrap_or(false)
})
.map(|r| *r.key())
.filter(|id| !self.recovered.contains(id))
.collect();
// Recover each worker
for worker_id in workers_to_recover {
tracing::info!(
"{log_prefix}: Worker {worker_id} added, dumping local indexer into router"
);
let recovered = self.recover_all_dp_ranks(worker_id, event_tx).await;
if recovered > 0 {
tracing::info!(
"{log_prefix}: Worker {worker_id} recovered {recovered} events from local indexer"
);
}
self.recovered.insert(worker_id);
}
}
/// Check if a worker has local indexer enabled /// Check if a worker has local indexer enabled
pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool { fn has_local_indexer(&self, worker_id: WorkerId) -> bool {
self.subscriber self.subscriber
.configs .configs
.get(&worker_id) .get(&worker_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment