// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::{collections::HashMap, collections::HashSet, time::Duration}; use anyhow::Result; use dynamo_runtime::{ component::Component, config::environment_names::nats as env_nats, discovery::{DiscoveryEvent, DiscoveryQuery, EventTransportKind}, prelude::*, transports::event_plane::EventSubscriber, transports::nats::{NatsQueue, Slug}, }; use futures::StreamExt; use rand::Rng; use tokio_util::sync::CancellationToken; use crate::kv_router::{ Indexer, KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE, protocols::{DpRank, RouterEvent, WorkerId}, router_discovery_query, worker_query::WorkerQueryClient, }; /// Helper function to create a KV stream name from a component and subject. /// /// Generates a slugified stream name in the format: /// `namespace-{namespace}-component-{component}-{subject}` fn create_kv_stream_name(component: &Component, subject: &str) -> String { Slug::slugify(&format!( "namespace.{}.component.{}.{}", component.namespace().name(), component.name(), subject )) .to_string() .replace("_", "-") } /// Delay between snapshot reads to verify stability const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100); const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10; const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1); const CHECK_INTERVAL_JITTER_MS: i64 = 100; // ============================================================================ // Discovery Helpers // ============================================================================ /// Get the instance discovery stream for monitoring worker add/remove events. /// Waits for at least one instance to be discovered before returning. async fn get_instance_discovery_stream( component: &Component, cancellation_token: &CancellationToken, ) -> Result> + Send>>> { let discovery_client = component.drt().discovery(); let generate_discovery_key = DiscoveryQuery::Endpoint { namespace: component.namespace().name().to_string(), component: component.name().to_string(), endpoint: "generate".to_string(), }; let mut stream = discovery_client .list_and_watch(generate_discovery_key, Some(cancellation_token.clone())) .await? .peekable(); tracing::info!("KV subscriber waiting for at least one worker instance..."); std::pin::Pin::new(&mut stream).peek().await; Ok(Box::pin(stream)) } // ============================================================================ // Snapshot Management // ============================================================================ /// Download a stable snapshot from object store and send events to the indexer. /// Retries until two consecutive reads match or max attempts is reached. async fn download_stable_snapshot( nats_client: &dynamo_runtime::transports::nats::Client, bucket_name: &str, indexer: &Indexer, ) -> Result<()> { let url = url::Url::parse(&format!( "nats://{}/{bucket_name}/{RADIX_STATE_FILE}", nats_client.addr() ))?; // Try to get initial snapshot let Ok(mut prev_events) = nats_client .object_store_download_data::>(&url) .await else { tracing::debug!( "Failed to download snapshots. This is normal for freshly started Router replicas." ); return Ok(()); }; // Keep trying until we get two consecutive stable reads for attempt in 1..=MAX_SNAPSHOT_STABILITY_ATTEMPTS { tokio::time::sleep(SNAPSHOT_STABILITY_DELAY).await; let curr_events = match nats_client .object_store_download_data::>(&url) .await { Ok(events) => events, Err(e) => { tracing::warn!( "Snapshot read failed on attempt {attempt}, using previous snapshot with {} events: {e:?}", prev_events.len() ); break; } }; // Check if snapshot is stable (two consecutive reads match) if prev_events == curr_events { tracing::info!( "Successfully downloaded stable snapshot with {} events from object store (stable after {attempt} attempts)", curr_events.len() ); prev_events = curr_events; break; } tracing::debug!( "Snapshot changed between reads on attempt {attempt} ({} -> {} events), retrying", prev_events.len(), curr_events.len() ); prev_events = curr_events; if attempt == MAX_SNAPSHOT_STABILITY_ATTEMPTS { tracing::warn!( "Max stability attempts reached, using latest snapshot with {} events", prev_events.len() ); } } // Send all events to the indexer for event in prev_events { indexer.apply_event(event).await; } tracing::info!("Successfully sent all initial events to indexer"); Ok(()) } /// Resources required for snapshot operations #[derive(Clone)] struct SnapshotResources { nats_client: dynamo_runtime::transports::nats::Client, bucket_name: String, instances_rx: tokio::sync::watch::Receiver>, indexer: Indexer, } impl SnapshotResources { /// Perform snapshot upload and purge operations async fn purge_then_snapshot(&self, nats_queue: &mut NatsQueue) -> anyhow::Result<()> { tracing::info!("Purging acknowledged messages and performing snapshot of radix tree"); let start_time = std::time::Instant::now(); // Clean up stale workers before snapshot let current_instances = self.instances_rx.borrow().clone(); let current_worker_ids: std::collections::HashSet = current_instances .iter() .map(|instance| instance.instance_id) .collect(); let indexer_worker_ids = self.indexer.get_workers().await; for worker_id in indexer_worker_ids { if !current_worker_ids.contains(&worker_id) { tracing::info!("Removing stale worker {worker_id} from indexer during snapshot"); self.indexer.remove_worker(worker_id).await; } } // First, purge acknowledged messages from the stream nats_queue.purge_acknowledged().await?; // Now request a snapshot from the indexer (which reflects the post-purge state) let events = self .indexer .dump_events() .await .map_err(|e| anyhow::anyhow!("Failed to dump events for snapshot: {e:?}"))?; // Upload the snapshot to NATS object store in background (non-blocking) let nats_client = self.nats_client.clone(); let bucket_name = self.bucket_name.clone(); let event_count = events.len(); tokio::spawn(async move { let Ok(url) = url::Url::parse(&format!( "nats://{}/{bucket_name}/{RADIX_STATE_FILE}", nats_client.addr(), )) else { tracing::warn!("Failed to parse snapshot URL"); return; }; if let Err(e) = nats_client.object_store_upload_data(&events, &url).await { tracing::warn!("Failed to upload snapshot: {e:?}"); return; } tracing::info!( "Successfully uploaded snapshot with {event_count} events to bucket {bucket_name} in {}ms", start_time.elapsed().as_millis() ); }); Ok(()) } } /// Start a unified background task for event consumption and optional snapshot management pub async fn start_kv_router_background( component: Component, consumer_id: String, indexer: Indexer, kv_router_config: &KvRouterConfig, ) -> Result<()> { let cancellation_token = component.drt().primary_token(); let router_snapshot_threshold = kv_router_config.router_snapshot_threshold; let router_reset_states = kv_router_config.router_reset_states; // Set up NATS connections let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT); let nats_server = std::env::var(env_nats::NATS_SERVER) .unwrap_or_else(|_| "nats://localhost:4222".to_string()); // Create NatsQueue for event consumption let mut nats_queue = NatsQueue::new_with_consumer( stream_name.clone(), nats_server.clone(), std::time::Duration::from_secs(60), // 1 minute timeout consumer_id.clone(), ); nats_queue.connect_with_reset(router_reset_states).await?; // Always create NATS client (needed for both reset and snapshots) let client_options = dynamo_runtime::transports::nats::Client::builder() .server(&nats_server) .build()?; let nats_client = client_options.connect().await?; // Create bucket name for snapshots/state let event_plane_subject = format!( "namespace.{}.component.{}", component.namespace().name(), component.name() ); let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", event_plane_subject)) .to_string() .replace("_", "-"); // Handle initial state based on router_reset_states flag if !router_reset_states { // Try to download initial state from object store with stability check download_stable_snapshot(&nats_client, &bucket_name, &indexer).await?; } else { // Delete the bucket to reset state tracing::info!("Resetting router state, deleting bucket: {bucket_name}"); if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await { tracing::warn!("Failed to delete bucket (may not exist): {e:?}"); } } // Cleanup orphaned consumers on startup cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await; // Wait for at least one worker instance before proceeding let mut instance_event_stream = get_instance_discovery_stream(&component, &cancellation_token).await?; // Watch for router deletions to clean up orphaned consumers via discovery let generate_endpoint = component.endpoint("generate"); let discovery_client = component.drt().discovery(); let router_discovery_key = router_discovery_query(component.namespace().name()); let mut router_event_stream = discovery_client .list_and_watch(router_discovery_key, Some(cancellation_token.clone())) .await?; // Get instances_rx for tracking current workers let client = generate_endpoint.client().await?; let instances_rx = client.instance_source.as_ref().clone(); // Only set up snapshot-related resources if snapshot threshold is configured let snapshot_resources = router_snapshot_threshold.map(|_| SnapshotResources { nats_client, bucket_name, instances_rx, indexer: indexer.clone(), }); tokio::spawn(async move { // Create interval with jitter let jitter_ms = rand::rng().random_range(-CHECK_INTERVAL_JITTER_MS..=CHECK_INTERVAL_JITTER_MS); let interval_duration = Duration::from_millis( (CHECK_INTERVAL_BASE.as_millis() as i64 + jitter_ms).max(1) as u64, ); let mut check_interval = tokio::time::interval(interval_duration); check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { tokio::select! { biased; _ = cancellation_token.cancelled() => { tracing::debug!("KV Router background task received cancellation signal"); // Clean up the queue and remove the durable consumer // TODO: durable consumer cannot cleanup if ungraceful shutdown (crash) if let Err(e) = nats_queue.shutdown(None).await { tracing::warn!("Failed to shutdown NatsQueue: {e}"); } break; } // Handle generate endpoint instance deletion events Some(discovery_event_result) = instance_event_stream.next() => { let Ok(discovery_event) = discovery_event_result else { continue; }; let DiscoveryEvent::Removed(id) = discovery_event else { continue; }; let worker_id = id.instance_id(); tracing::warn!( "DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}" ); indexer.remove_worker(worker_id).await; } // Handle event consumption result = nats_queue.dequeue_task(None) => { match result { Ok(Some(bytes)) => { let event: RouterEvent = match serde_json::from_slice(&bytes) { Ok(event) => event, Err(e) => { tracing::warn!("Failed to deserialize RouterEvent: {e:?}"); continue; } }; // Forward the RouterEvent to the indexer indexer.apply_event(event).await; }, Ok(None) => { tracing::trace!("Dequeue timeout, continuing"); }, Err(e) => { tracing::error!("Failed to dequeue task: {e:?}"); tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } } // Handle periodic stream checking and purging (only if snapshot_resources is provided) _ = check_interval.tick() => { let Some(resources) = snapshot_resources.as_ref() else { continue; }; // Check total messages in the stream let Ok(message_count) = nats_queue.get_stream_messages().await else { tracing::warn!("Failed to get stream message count"); continue; }; let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64; if message_count <= threshold { continue; } tracing::info!("Stream has {message_count} messages (threshold: {threshold}), performing purge and snapshot"); match resources.purge_then_snapshot( &mut nats_queue, ).await { Ok(_) => tracing::info!("Successfully performed purge and snapshot"), Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"), } } // Handle router deletion events via discovery Some(router_event_result) = router_event_stream.next() => { let Ok(router_event) = router_event_result else { continue; }; let DiscoveryEvent::Removed(id) = router_event else { // We only care about removals for cleaning up consumers continue; }; let router_instance_id = id.instance_id(); // The consumer ID is the instance_id as a string let consumer_to_delete = router_instance_id.to_string(); tracing::info!( "DISCOVERY: Router instance {router_instance_id} removed, attempting to delete orphaned consumer: {consumer_to_delete}" ); // Delete the consumer (allow race condition if multiple routers try to delete) if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await { tracing::warn!("Failed to delete consumer {consumer_to_delete}: {e}"); } else { tracing::info!("Successfully deleted orphaned consumer: {consumer_to_delete}"); } } } } // Clean up the queue and remove the durable consumer if let Err(e) = nats_queue.shutdown(None).await { tracing::warn!("Failed to shutdown NatsQueue: {e}"); } }); Ok(()) } /// Start a simplified background task for event consumption using the event plane. /// /// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`, /// this function: /// - Uses the event plane (NATS Core or ZMQ) instead of JetStream /// - Does not support snapshots, purging, or durable consumers /// - On worker Added: dumps worker's local indexer into router /// - On worker Removed: removes worker from router indexer /// /// This function first recovers state from all currently registered workers before /// spawning the background task, ensuring the router is ready before returning. /// /// This is appropriate when workers have local indexers enabled. pub async fn start_kv_router_background_event_plane( component: Component, indexer: Indexer, transport_kind: EventTransportKind, ) -> Result<()> { let cancellation_token = component.drt().primary_token(); // WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery. // No blocking wait — recovery happens asynchronously as endpoints are discovered. let worker_query_client = WorkerQueryClient::spawn(component.clone(), indexer.clone()).await?; // Subscribe to KV events using the selected event plane transport let mut subscriber = EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind) .await? .typed::(); let kv_event_subject = format!( "namespace.{}.component.{}.{}", component.namespace().name(), component.name(), KV_EVENT_SUBJECT ); match transport_kind { EventTransportKind::Nats => { tracing::info!( subject = %kv_event_subject, "KV Router using NATS Core subscription (local_indexer mode)" ); } EventTransportKind::Zmq => { tracing::info!( subject = %kv_event_subject, "KV Router using ZMQ event plane subscription (local_indexer mode)" ); } } tokio::spawn(async move { // Track last received event ID per (worker, dp_rank) for gap detection // Each dp_rank has its own monotonic event ID sequence let mut last_event_ids: HashMap<(WorkerId, DpRank), u64> = HashMap::new(); loop { tokio::select! { biased; _ = cancellation_token.cancelled() => { tracing::debug!("KV Router event plane background task received cancellation signal"); break; } // Handle event consumption from event plane subscription Some(result) = subscriber.next() => { let (envelope, event) = match result { Ok((envelope, event)) => (envelope, event), Err(e) => { tracing::warn!("Failed to receive RouterEvent from event plane: {e:?}"); continue; } }; let worker_id = event.worker_id; let dp_rank = event.event.dp_rank; let event_id = event.event.event_id; let event_key = (worker_id, dp_rank); tracing::trace!( "Received event from publisher {} (seq {})", envelope.publisher_id, envelope.sequence ); // Gap detection: check if event ID is monotonically increasing per (worker, dp_rank) // Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent) if let Some(&last_id) = last_event_ids.get(&event_key) && event_id > last_id + 1 { let gap_start = last_id + 1; let gap_end = event_id - 1; let gap_size = gap_end - gap_start + 1; tracing::warn!( "Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}" ); if let Err(e) = worker_query_client .recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end)) .await { tracing::error!( "Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}" ); } } // Update last seen event ID (use max to handle out-of-order) last_event_ids .entry(event_key) .and_modify(|id| *id = (*id).max(event_id)) .or_insert(event_id); // Forward the RouterEvent to the indexer indexer.apply_event(event).await; } } } tracing::debug!("KV Router event plane background task exiting"); }); Ok(()) } /// Cleanup orphaned NATS consumers that no longer have corresponding router entries async fn cleanup_orphaned_consumers( nats_queue: &mut NatsQueue, component: &Component, consumer_id: &str, ) { let Ok(consumers) = nats_queue.list_consumers().await else { return; }; // Get active routers from discovery let discovery = component.drt().discovery(); let Ok(router_instances) = discovery .list(router_discovery_query(component.namespace().name())) .await else { tracing::debug!("Failed to list router instances from discovery, skipping cleanup"); return; }; // Build set of active router instance IDs let active_instance_ids: HashSet = router_instances .iter() .map(|instance| instance.instance_id().to_string()) .collect(); for consumer in consumers { if consumer == consumer_id { // Never delete myself (extra/redundant safeguard) continue; } if !active_instance_ids.contains(&consumer) { tracing::info!("Cleaning up orphaned consumer: {consumer}"); let _ = nats_queue.shutdown(Some(consumer)).await; } } } /// Helper to decide which subscriber (JetStream or Event Plane) to start based on config pub async fn start_subscriber( component: Component, kv_router_config: &KvRouterConfig, indexer: Indexer, ) -> Result<()> { let transport_kind = EventTransportKind::from_env_or_default(); // Start subscriber - durable_kv_events flag determines the mode: // - durable_kv_events=false (default): Use NATS Core / generic event plane (requires workers to have local_indexer enabled) // - durable_kv_events=true: Use JetStream for durability and multi-replica consistency if kv_router_config.durable_kv_events { if transport_kind == EventTransportKind::Zmq { tracing::warn!( "--durable-kv-events requires NATS, but ZMQ event plane is configured; falling back to JetStream anyway" ); } tracing::info!("Using JetStream subscription (--durable-kv-events enabled)"); let consumer_id = component.drt().discovery().instance_id().to_string(); start_kv_router_background(component, consumer_id, indexer, kv_router_config).await } else { if transport_kind == EventTransportKind::Zmq { if kv_router_config.router_snapshot_threshold.is_some() || kv_router_config.router_reset_states { tracing::warn!( "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings" ); } tracing::info!("Using ZMQ event plane subscription (local_indexer mode)"); } else { tracing::info!("Using NATS Core subscription (local_indexer mode)"); } start_kv_router_background_event_plane(component, indexer, transport_kind).await } }