// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::collections::HashSet; use std::time::Duration; use anyhow::Result; use dynamo_kv_router::{config::KvRouterConfig, protocols::RouterEvent}; use dynamo_runtime::{ component::Component, config::environment_names::nats as env_nats, discovery::{DiscoveryEvent, DiscoveryQuery}, prelude::*, transports::nats::{NatsQueue, Slug}, }; use futures::StreamExt; use rand::Rng; use tokio_util::sync::CancellationToken; use crate::kv_router::{ Indexer, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, router_discovery_query, }; /// Helper function to create a KV stream name from a component and subject. /// /// Generates a slugified stream name in the format: /// `namespace-{namespace}-component-{component}-{subject}` fn create_kv_stream_name(component: &Component, subject: &str) -> String { Slug::slugify(&format!( "namespace.{}.component.{}.{}", component.namespace().name(), component.name(), subject )) .to_string() .replace("_", "-") } /// Delay between snapshot reads to verify stability const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100); const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10; const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1); const CHECK_INTERVAL_JITTER_MS: i64 = 100; // ============================================================================ // Discovery Helpers // ============================================================================ /// Get the instance discovery stream for monitoring worker add/remove events. /// Waits for at least one instance to be discovered before returning. async fn get_instance_discovery_stream( component: &Component, cancellation_token: &CancellationToken, ) -> Result> + Send>>> { let discovery_client = component.drt().discovery(); let generate_discovery_key = DiscoveryQuery::Endpoint { namespace: component.namespace().name().to_string(), component: component.name().to_string(), endpoint: "generate".to_string(), }; let mut stream = discovery_client .list_and_watch(generate_discovery_key, Some(cancellation_token.clone())) .await? .peekable(); tracing::info!("KV subscriber waiting for at least one worker instance..."); std::pin::Pin::new(&mut stream).peek().await; Ok(Box::pin(stream)) } // ============================================================================ // Snapshot Management // ============================================================================ /// Download a stable snapshot from object store and send events to the indexer. /// Retries until two consecutive reads match or max attempts is reached. async fn download_stable_snapshot( nats_client: &dynamo_runtime::transports::nats::Client, bucket_name: &str, indexer: &Indexer, ) -> Result<()> { let url = url::Url::parse(&format!( "nats://{}/{bucket_name}/{RADIX_STATE_FILE}", nats_client.addr() ))?; // Try to get initial snapshot let Ok(mut prev_events) = nats_client .object_store_download_data::>(&url) .await else { tracing::debug!( "Failed to download snapshots. This is normal for freshly started Router replicas." ); return Ok(()); }; // Keep trying until we get two consecutive stable reads for attempt in 1..=MAX_SNAPSHOT_STABILITY_ATTEMPTS { tokio::time::sleep(SNAPSHOT_STABILITY_DELAY).await; let curr_events = match nats_client .object_store_download_data::>(&url) .await { Ok(events) => events, Err(e) => { tracing::warn!( "Snapshot read failed on attempt {attempt}, using previous snapshot with {} events: {e:?}", prev_events.len() ); break; } }; // Check if snapshot is stable (two consecutive reads match) if prev_events == curr_events { tracing::info!( "Successfully downloaded stable snapshot with {} events from object store (stable after {attempt} attempts)", curr_events.len() ); prev_events = curr_events; break; } tracing::debug!( "Snapshot changed between reads on attempt {attempt} ({} -> {} events), retrying", prev_events.len(), curr_events.len() ); prev_events = curr_events; if attempt == MAX_SNAPSHOT_STABILITY_ATTEMPTS { tracing::warn!( "Max stability attempts reached, using latest snapshot with {} events", prev_events.len() ); } } // Send all events to the indexer for event in prev_events { indexer.apply_event(event).await; } tracing::info!("Successfully sent all initial events to indexer"); Ok(()) } /// Resources required for snapshot operations #[derive(Clone)] struct SnapshotResources { nats_client: dynamo_runtime::transports::nats::Client, bucket_name: String, instances_rx: tokio::sync::watch::Receiver>, indexer: Indexer, } impl SnapshotResources { /// Perform snapshot upload and purge operations async fn purge_then_snapshot(&self, nats_queue: &mut NatsQueue) -> anyhow::Result<()> { tracing::info!("Purging acknowledged messages and performing snapshot of radix tree"); let start_time = std::time::Instant::now(); // Clean up stale workers before snapshot let current_instances = self.instances_rx.borrow().clone(); let current_worker_ids: std::collections::HashSet = current_instances .iter() .map(|instance| instance.instance_id) .collect(); let indexer_worker_ids = self.indexer.get_workers().await; for worker_id in indexer_worker_ids { if !current_worker_ids.contains(&worker_id) { tracing::info!("Removing stale worker {worker_id} from indexer during snapshot"); self.indexer.remove_worker(worker_id).await; } } // First, purge acknowledged messages from the stream nats_queue.purge_acknowledged().await?; // Now request a snapshot from the indexer (which reflects the post-purge state) let events = self .indexer .dump_events() .await .map_err(|e| anyhow::anyhow!("Failed to dump events for snapshot: {e:?}"))?; // Upload the snapshot to NATS object store in background (non-blocking) let nats_client = self.nats_client.clone(); let bucket_name = self.bucket_name.clone(); let event_count = events.len(); tokio::spawn(async move { let Ok(url) = url::Url::parse(&format!( "nats://{}/{bucket_name}/{RADIX_STATE_FILE}", nats_client.addr(), )) else { tracing::warn!("Failed to parse snapshot URL"); return; }; if let Err(e) = nats_client.object_store_upload_data(&events, &url).await { tracing::warn!("Failed to upload snapshot: {e:?}"); return; } tracing::info!( "Successfully uploaded snapshot with {event_count} events to bucket {bucket_name} in {}ms", start_time.elapsed().as_millis() ); }); Ok(()) } } /// Start a unified background task for event consumption and optional snapshot management pub(crate) async fn start_kv_router_background( component: Component, consumer_id: String, indexer: Indexer, kv_router_config: &KvRouterConfig, ) -> Result<()> { let cancellation_token = component.drt().primary_token(); let router_snapshot_threshold = kv_router_config.router_snapshot_threshold; let router_reset_states = kv_router_config.router_reset_states; // Set up NATS connections let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT); let nats_server = std::env::var(env_nats::NATS_SERVER) .unwrap_or_else(|_| "nats://localhost:4222".to_string()); // Create NatsQueue for event consumption let mut nats_queue = NatsQueue::new_with_consumer( stream_name.clone(), nats_server.clone(), std::time::Duration::from_secs(60), // 1 minute timeout consumer_id.clone(), ); nats_queue.connect_with_reset(router_reset_states).await?; // Always create NATS client (needed for both reset and snapshots) let client_options = dynamo_runtime::transports::nats::Client::builder() .server(&nats_server) .build()?; let nats_client = client_options.connect().await?; // Create bucket name for snapshots/state let event_plane_subject = format!( "namespace.{}.component.{}", component.namespace().name(), component.name() ); let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", event_plane_subject)) .to_string() .replace("_", "-"); // Handle initial state based on router_reset_states flag if !router_reset_states { // Try to download initial state from object store with stability check download_stable_snapshot(&nats_client, &bucket_name, &indexer).await?; } else { // Delete the bucket to reset state tracing::info!("Resetting router state, deleting bucket: {bucket_name}"); if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await { tracing::warn!("Failed to delete bucket (may not exist): {e:?}"); } } // Cleanup orphaned consumers on startup cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await; // Wait for at least one worker instance before proceeding let mut instance_event_stream = get_instance_discovery_stream(&component, &cancellation_token).await?; // Watch for router deletions to clean up orphaned consumers via discovery let generate_endpoint = component.endpoint("generate"); let discovery_client = component.drt().discovery(); let router_discovery_key = router_discovery_query(component.namespace().name(), component.name().to_string()); let mut router_event_stream = discovery_client .list_and_watch(router_discovery_key, Some(cancellation_token.clone())) .await?; // Get instances_rx for tracking current workers let client = generate_endpoint.client().await?; let instances_rx = client.instance_source.as_ref().clone(); // Only set up snapshot-related resources if snapshot threshold is configured let snapshot_resources = router_snapshot_threshold.map(|_| SnapshotResources { nats_client, bucket_name, instances_rx, indexer: indexer.clone(), }); tokio::spawn(async move { // Create interval with jitter let jitter_ms = rand::rng().random_range(-CHECK_INTERVAL_JITTER_MS..=CHECK_INTERVAL_JITTER_MS); let interval_duration = Duration::from_millis( (CHECK_INTERVAL_BASE.as_millis() as i64 + jitter_ms) .max(1) .cast_unsigned(), ); let mut check_interval = tokio::time::interval(interval_duration); check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { tokio::select! { biased; _ = cancellation_token.cancelled() => { tracing::debug!("KV Router background task received cancellation signal"); break; } // Handle generate endpoint instance deletion events Some(discovery_event_result) = instance_event_stream.next() => { let Ok(discovery_event) = discovery_event_result else { continue; }; let DiscoveryEvent::Removed(id) = discovery_event else { continue; }; let worker_id = id.instance_id(); tracing::warn!( "DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}" ); indexer.remove_worker(worker_id).await; } // Handle event consumption result = nats_queue.dequeue_task(None) => { match result { Ok(Some(bytes)) => { let event: RouterEvent = match serde_json::from_slice(&bytes) { Ok(event) => event, Err(e) => { tracing::warn!("Failed to deserialize RouterEvent: {e:?}"); continue; } }; // Forward the RouterEvent to the indexer indexer.apply_event(event).await; }, Ok(None) => { tracing::trace!("Dequeue timeout, continuing"); }, Err(e) => { tracing::error!("Failed to dequeue task: {e:?}"); tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } } // Handle periodic stream checking and purging (only if snapshot_resources is provided) _ = check_interval.tick() => { let Some(resources) = snapshot_resources.as_ref() else { continue; }; // Check total messages in the stream let Ok(message_count) = nats_queue.get_stream_messages().await else { tracing::warn!("Failed to get stream message count"); continue; }; let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64; if message_count <= threshold { continue; } tracing::info!("Stream has {message_count} messages (threshold: {threshold}), performing purge and snapshot"); match resources.purge_then_snapshot( &mut nats_queue, ).await { Ok(_) => tracing::info!("Successfully performed purge and snapshot"), Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"), } } // Handle router deletion events via discovery Some(router_event_result) = router_event_stream.next() => { let Ok(router_event) = router_event_result else { continue; }; let DiscoveryEvent::Removed(id) = router_event else { // We only care about removals for cleaning up consumers continue; }; let router_instance_id = id.instance_id(); // The consumer ID is the instance_id as a string let consumer_to_delete = router_instance_id.to_string(); tracing::info!( "DISCOVERY: Router instance {router_instance_id} removed, attempting to delete orphaned consumer: {consumer_to_delete}" ); // Delete the consumer (allow race condition if multiple routers try to delete) if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await { tracing::warn!("Failed to delete consumer {consumer_to_delete}: {e}"); } else { tracing::info!("Successfully deleted orphaned consumer: {consumer_to_delete}"); } } } } // Clean up the queue and remove the durable consumer if let Err(e) = nats_queue.shutdown(None).await { tracing::warn!("Failed to shutdown NatsQueue: {e}"); } }); Ok(()) } /// Cleanup orphaned NATS consumers that no longer have corresponding router entries async fn cleanup_orphaned_consumers( nats_queue: &mut NatsQueue, component: &Component, consumer_id: &str, ) { let Ok(consumers) = nats_queue.list_consumers().await else { return; }; // Get active routers from discovery let discovery = component.drt().discovery(); let Ok(router_instances) = discovery .list(router_discovery_query( component.namespace().name(), component.name().to_string(), )) .await else { tracing::debug!("Failed to list router instances from discovery, skipping cleanup"); return; }; // Build set of active router instance IDs let active_instance_ids: HashSet = router_instances .iter() .map(|instance| instance.instance_id().to_string()) .collect(); for consumer in consumers { if consumer == consumer_id { // Never delete myself (extra/redundant safeguard) continue; } if !active_instance_ids.contains(&consumer) { tracing::info!("Cleaning up orphaned consumer: {consumer}"); let _ = nats_queue.shutdown(Some(consumer)).await; } } }