Unverified commit c7f6f6d9, authored by Janelle Cai and committed by GitHub
Browse files

fix: race between worker discovery and runtimeconfig discovery in KV router (#5924)


Signed-off-by: Janelle Cai <jcai18@mit.edu>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Yan Ru Pei <yanrpei@gmail.com>
parent 72a12869
......@@ -386,11 +386,12 @@ impl KvRouter {
.await?;
// Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery)
// (for query/recovery API methods - no lifecycle tracking needed)
// Uses a subscriber from workers_with_configs
let worker_query_client = worker_query::WorkerQueryClient::new(
component.clone(),
workers_with_configs.subscribe(),
None, // No removal channel - query only
);
tracing::info!("Worker query client initialized");
......@@ -431,11 +432,11 @@ impl KvRouter {
start_kv_router_background_event_plane(
component.clone(),
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
cancellation_token.clone(),
worker_query::WorkerQueryClient::new(
component.clone(),
workers_with_configs.subscribe(),
Some(kv_indexer.remove_worker_sender()),
),
transport_kind,
)
......
......@@ -495,41 +495,6 @@ pub async fn start_kv_router_background(
Ok(())
}
/// React to a single worker discovery event.
///
/// * `DiscoveryEvent::Added` - calls `recover_all_dp_ranks` on `worker_query_client`
///   for the new worker, passing `kv_events_tx` as the destination, and logs the
///   returned total when it is greater than zero.
/// * `DiscoveryEvent::Removed` - forwards the worker id on `remove_worker_tx`. A failed
///   send is logged and otherwise ignored.
async fn handle_worker_discovery(
    event: DiscoveryEvent,
    worker_query_client: &WorkerQueryClient,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
    remove_worker_tx: &mpsc::Sender<WorkerId>,
) {
    match event {
        DiscoveryEvent::Removed(id) => {
            let worker_id = id.instance_id();
            tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");

            // Best effort: a failed send is only reported, never propagated.
            match remove_worker_tx.send(worker_id).await {
                Ok(()) => {}
                Err(e) => {
                    tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
                }
            }
        }
        DiscoveryEvent::Added(instance) => {
            let worker_id = instance.instance_id();
            tracing::info!(
                "DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
            );

            let recovery = worker_query_client.recover_all_dp_ranks(worker_id, kv_events_tx);
            let total_recovered = recovery.await;

            // Stay quiet when nothing was recovered.
            if total_recovered > 0 {
                tracing::info!(
                    "DISCOVERY: Worker {worker_id} total recovered {total_recovered} events"
                );
            }
        }
    }
}
/// Start a simplified background task for event consumption using the event plane.
///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
......@@ -546,7 +511,6 @@ async fn handle_worker_discovery(
pub async fn start_kv_router_background_event_plane(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken,
mut worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind,
......@@ -587,18 +551,10 @@ pub async fn start_kv_router_background_event_plane(
ready_workers.len()
);
// Recover initial state from all ready workers (all dp_ranks)
for worker_id in &ready_workers {
if worker_query_client.has_local_indexer(*worker_id) {
// Recover initial state from all workers with local indexer enabled
worker_query_client
.recover_all_dp_ranks(*worker_id, &kv_events_tx)
.process_and_recover_workers(&kv_events_tx, "Initial recovery")
.await;
}
}
// Get instance discovery stream for ongoing monitoring of worker add/remove events
let mut instance_event_stream =
get_instance_discovery_stream(&component, &cancellation_token).await?;
tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection
......@@ -614,18 +570,15 @@ pub async fn start_kv_router_background_event_plane(
break;
}
// Handle generate endpoint instance add/remove events
Some(discovery_event_result) = instance_event_stream.next() => {
let Ok(event) = discovery_event_result else {
// Handle runtime config changes (worker add/remove, recovery for new workers)
result = worker_query_client.wait_for_config_change() => {
if result.is_err() {
tracing::warn!("Runtime config watch sender dropped");
continue;
};
}
handle_worker_discovery(
event,
&worker_query_client,
&kv_events_tx,
&remove_worker_tx,
)
worker_query_client
.process_and_recover_workers(&kv_events_tx, "DISCOVERY")
.await;
}
......@@ -708,14 +661,12 @@ pub async fn start_kv_router_background_event_plane(
pub async fn start_kv_router_background_nats_core(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient,
) -> Result<()> {
start_kv_router_background_event_plane(
component,
kv_events_tx,
remove_worker_tx,
cancellation_token,
worker_query_client,
EventTransportKind::Nats,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
......@@ -33,21 +34,41 @@ const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
///
/// Each dp_rank has its own LocalKvIndexer and query endpoint, so we maintain separate
/// routers per dp_rank to ensure queries go to the correct endpoint.
///
/// Also handles worker lifecycle (add/remove) by tracking known workers and sending
/// removal events to the router indexer.
pub struct WorkerQueryClient {
/// Component handed to `new` and stored unchanged.
component: Component,
/// Subscriber for runtime configs (includes shared configs DashMap).
/// Its `configs` map is the source of truth for which workers currently exist, and
/// its `change_rx` is what `wait_for_config_change` awaits.
subscriber: RuntimeConfigsSubscriber,
/// Routers keyed by dp_rank - each dp_rank has its own endpoint
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
/// Workers for which `process_and_recover_workers` has already run recovery.
/// An id is inserted after `recover_all_dp_ranks` returns (whatever count it reports)
/// and ids already present are skipped on later passes.
recovered: HashSet<WorkerId>,
/// Worker ids seen in the runtime configs on the previous
/// `process_and_recover_workers` pass; compared with the current configs to detect
/// removals.
known_workers: HashSet<WorkerId>,
/// Channel to send worker removal events to the router indexer (optional).
/// When `None`, no removal events are sent.
remove_worker_tx: Option<mpsc::Sender<WorkerId>>,
}
impl WorkerQueryClient {
/// Create a new WorkerQueryClient with a subscriber to runtime configs
pub fn new(component: Component, subscriber: RuntimeConfigsSubscriber) -> Self {
/// Build a `WorkerQueryClient` on top of a runtime-config subscriber.
///
/// The client starts with no routers, no recovered workers and no known workers.
///
/// `remove_worker_tx` selects the mode of operation:
/// * `Some(tx)` - worker removals detected by this client are forwarded on `tx`.
/// * `None` - no removal events are sent (suitable for query-only usage).
pub fn new(
    component: Component,
    subscriber: RuntimeConfigsSubscriber,
    remove_worker_tx: Option<mpsc::Sender<WorkerId>>,
) -> Self {
    // All bookkeeping starts empty; it is filled in lazily as workers are observed.
    let routers = DashMap::new();
    let recovered = HashSet::new();
    let known_workers = HashSet::new();

    Self {
        component,
        subscriber,
        routers,
        recovered,
        known_workers,
        remove_worker_tx,
    }
}
......@@ -57,8 +78,84 @@ impl WorkerQueryClient {
self.subscriber.wait_for_some().await
}
/// Suspend until the runtime-config watch channel reports a change.
///
/// Resolves to `Ok(())` once the configs have changed, or to `Err` when the sending
/// side of the watch channel has been dropped. A dropped sender is permanent, so
/// callers should treat `Err` as terminal rather than polling again in a loop.
pub async fn wait_for_config_change(
    &mut self,
) -> Result<(), tokio::sync::watch::error::RecvError> {
    let change_rx = &mut self.subscriber.change_rx;
    change_rx.changed().await
}
/// Process config changes and recover pending workers.
///
/// This method:
/// 1. Detects removed workers and sends removal events (if remove_worker_tx is set)
/// 2. Recovers workers that have config + local_indexer enabled but haven't been recovered yet
/// 3. Marks recovered workers so they won't be recovered again
///
/// Should be called after `wait_for_config_change()` returns.
///
/// # Arguments
/// * `event_tx` - Channel to send recovered events to the router indexer
/// * `log_prefix` - Prefix for log messages (e.g., "DISCOVERY" or "Initial recovery")
pub async fn process_and_recover_workers(
&mut self,
event_tx: &mpsc::Sender<RouterEvent>,
log_prefix: &str,
) {
// Get current workers from configs
let current_workers: HashSet<WorkerId> =
self.subscriber.configs.iter().map(|r| *r.key()).collect();
// Handle removed workers (only if we have a removal channel)
if let Some(ref remove_worker_tx) = self.remove_worker_tx {
for worker_id in self.known_workers.difference(&current_workers) {
self.recovered.remove(worker_id);
tracing::warn!(
"{log_prefix}: Worker {worker_id} removed, removing from router indexer"
);
if let Err(e) = remove_worker_tx.send(*worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
}
}
self.known_workers = current_workers;
// Find workers needing recovery:
// - Has config (Some)
// - Has local_indexer enabled
// - Not yet recovered
let workers_to_recover: Vec<WorkerId> = self
.subscriber
.configs
.iter()
.filter(|r| {
r.value()
.as_ref()
.map(|c| c.enable_local_indexer)
.unwrap_or(false)
})
.map(|r| *r.key())
.filter(|id| !self.recovered.contains(id))
.collect();
// Recover each worker
for worker_id in workers_to_recover {
tracing::info!(
"{log_prefix}: Worker {worker_id} added, dumping local indexer into router"
);
let recovered = self.recover_all_dp_ranks(worker_id, event_tx).await;
if recovered > 0 {
tracing::info!(
"{log_prefix}: Worker {worker_id} recovered {recovered} events from local indexer"
);
}
self.recovered.insert(worker_id);
}
}
/// Check if a worker has local indexer enabled
pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool {
fn has_local_indexer(&self, worker_id: WorkerId) -> bool {
self.subscriber
.configs
.get(&worker_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment