// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Background processes for the KV Router including event consumption and snapshot uploads.

6
use std::{collections::HashSet, time::Duration};
7
8
9
10

use anyhow::Result;
use dynamo_runtime::{
    component::Component,
11
    config::environment_names::nats as env_nats,
12
    discovery::{DiscoveryEvent, DiscoveryQuery},
13
14
    prelude::*,
    traits::events::EventPublisher,
15
    transports::nats::{NatsQueue, Slug},
16
};
17
use futures::StreamExt;
18
use rand::Rng;
19
20
21
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;

22
23
24
25
26
use crate::kv_router::{
    KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
    indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
    protocols::WorkerId,
    router_discovery_query,
27
28
};

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
/// Delay between snapshot reads to verify stability
const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100);
/// Upper bound on the number of re-reads `download_stable_snapshot` performs
/// (after the initial read) while waiting for two consecutive reads to match.
const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;

/// Base period of the background task's periodic stream check.
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
/// Maximum magnitude, in milliseconds, of the random offset added to
/// `CHECK_INTERVAL_BASE`; the offset is drawn once from
/// `-CHECK_INTERVAL_JITTER_MS..=CHECK_INTERVAL_JITTER_MS` when the task starts.
const CHECK_INTERVAL_JITTER_MS: i64 = 100;

/// Download a stable snapshot from object store and send its events to the indexer.
///
/// The snapshot object is read once, then re-read every
/// `SNAPSHOT_STABILITY_DELAY` until two consecutive reads compare equal or
/// `MAX_SNAPSHOT_STABILITY_ATTEMPTS` re-reads have been made. If a re-read
/// fails, the most recent successful read is used.
///
/// # Arguments
/// * `nats_client` - client used to download from the object store.
/// * `bucket_name` - bucket holding the `RADIX_STATE_FILE` object.
/// * `kv_events_tx` - channel on which the snapshot's events are forwarded.
///
/// # Errors
/// Returns an error only if the snapshot URL cannot be parsed. A failed
/// initial download and a failed send to the indexer are both logged and
/// reported as `Ok(())`.
async fn download_stable_snapshot(
    nats_client: &dynamo_runtime::transports::nats::Client,
    bucket_name: &str,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
) -> Result<()> {
    let url = url::Url::parse(&format!(
        "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
        nats_client.addr()
    ))?;

    // Try to get initial snapshot
    let Ok(mut prev_events) = nats_client
        .object_store_download_data::<Vec<RouterEvent>>(&url)
        .await
    else {
        tracing::debug!(
            "Failed to download snapshots. This is normal for freshly started Router replicas."
        );
        return Ok(());
    };

    // Keep trying until we get two consecutive stable reads
    for attempt in 1..=MAX_SNAPSHOT_STABILITY_ATTEMPTS {
        tokio::time::sleep(SNAPSHOT_STABILITY_DELAY).await;

        let curr_events = match nats_client
            .object_store_download_data::<Vec<RouterEvent>>(&url)
            .await
        {
            Ok(events) => events,
            Err(e) => {
                tracing::warn!(
                    "Snapshot read failed on attempt {attempt}, using previous snapshot with {} events: {e:?}",
                    prev_events.len()
                );
                break;
            }
        };

        // Check if snapshot is stable (two consecutive reads match).
        // `prev_events` already equals `curr_events`, so no reassignment is needed.
        if prev_events == curr_events {
            tracing::info!(
                "Successfully downloaded stable snapshot with {} events from object store (stable after {attempt} attempts)",
                curr_events.len()
            );
            break;
        }

        tracing::debug!(
            "Snapshot changed between reads on attempt {attempt} ({} -> {} events), retrying",
            prev_events.len(),
            curr_events.len()
        );
        prev_events = curr_events;

        if attempt == MAX_SNAPSHOT_STABILITY_ATTEMPTS {
            tracing::warn!(
                "Max stability attempts reached, using latest snapshot with {} events",
                prev_events.len()
            );
        }
    }

    // Send all events to the indexer. A send error means the receiving half is
    // gone, so every later send would fail too: stop at the first failure
    // instead of warning once per event and then claiming success.
    let total_events = prev_events.len();
    for (sent, event) in prev_events.into_iter().enumerate() {
        if let Err(e) = kv_events_tx.send(event).await {
            tracing::warn!(
                "Failed to send initial event to indexer after {sent} of {total_events} events, aborting replay: {e:?}"
            );
            return Ok(());
        }
    }
    tracing::info!("Successfully sent all initial events to indexer");

    Ok(())
}

113
114
115
116
117
/// Resources required for snapshot operations
#[derive(Clone)]
struct SnapshotResources {
    nats_client: dynamo_runtime::transports::nats::Client,
    bucket_name: String,
118
119
120
    instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
    get_workers_tx: mpsc::Sender<GetWorkersRequest>,
    snapshot_tx: mpsc::Sender<DumpRequest>,
121
122
123
}

impl SnapshotResources {
124
    /// Perform snapshot upload and purge operations
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    async fn purge_then_snapshot(
        &self,
        nats_queue: &mut NatsQueue,
        remove_worker_tx: &mpsc::Sender<WorkerId>,
    ) -> anyhow::Result<()> {
        // Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
        // Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
        // at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
        tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
        let start_time = std::time::Instant::now();

        // Clean up stale workers before snapshot
        // Get current worker IDs from instances_rx
        let current_instances = self.instances_rx.borrow().clone();
139
        let current_worker_ids: std::collections::HashSet<u64> = current_instances
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
            .iter()
            .map(|instance| instance.instance_id)
            .collect();

        // Get worker IDs from the indexer
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let get_workers_req = GetWorkersRequest { resp: resp_tx };

        if let Err(e) = self.get_workers_tx.send(get_workers_req).await {
            tracing::warn!("Failed to send get_workers request during snapshot: {e:?}");
        } else {
            match resp_rx.await {
                Ok(indexer_worker_ids) => {
                    // Find workers in indexer but not in current instances
                    for worker_id in indexer_worker_ids {
                        if !current_worker_ids.contains(&worker_id) {
                            tracing::info!(
157
                                "Removing stale worker {worker_id} from indexer during snapshot"
158
159
160
                            );
                            if let Err(e) = remove_worker_tx.send(worker_id).await {
                                tracing::warn!(
161
                                    "Failed to send remove_worker for stale worker {worker_id}: {e:?}"
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
                                );
                            }
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to receive worker IDs from indexer: {e:?}");
                }
            }
        }

        // First, purge acknowledged messages from the stream
        nats_queue.purge_acknowledged().await?;

        // Now request a snapshot from the indexer (which reflects the post-purge state)
        let (resp_tx, resp_rx) = oneshot::channel();
        let dump_req = DumpRequest { resp: resp_tx };

        self.snapshot_tx
            .send(dump_req)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;

        // Wait for the dump response
        let events = resp_rx
            .await
            .map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
        // Upload the snapshot to NATS object store in background (non-blocking)
        let nats_client = self.nats_client.clone();
        let bucket_name = self.bucket_name.clone();
        let event_count = events.len();
        tokio::spawn(async move {
            let Ok(url) = url::Url::parse(&format!(
                "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
                nats_client.addr(),
            )) else {
                tracing::warn!("Failed to parse snapshot URL");
                return;
            };

            if let Err(e) = nats_client.object_store_upload_data(&events, &url).await {
                tracing::warn!("Failed to upload snapshot: {e:?}");
                return;
            }
207

208
209
210
211
212
            tracing::info!(
                "Successfully uploaded snapshot with {event_count} events to bucket {bucket_name} in {}ms",
                start_time.elapsed().as_millis()
            );
        });
213
214
215

        Ok(())
    }
216
217
218
}

/// Start a unified background task for event consumption and optional snapshot management
219
#[allow(clippy::too_many_arguments)]
220
221
pub async fn start_kv_router_background(
    component: Component,
222
    consumer_id: String,
223
    kv_events_tx: mpsc::Sender<RouterEvent>,
224
225
226
    remove_worker_tx: mpsc::Sender<WorkerId>,
    maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
    maybe_snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
227
228
229
230
231
232
233
234
    cancellation_token: CancellationToken,
    router_snapshot_threshold: Option<u32>,
    router_reset_states: bool,
) -> Result<()> {
    // Set up NATS connections
    let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
        .to_string()
        .replace("_", "-");
235
236
    let nats_server = std::env::var(env_nats::NATS_SERVER)
        .unwrap_or_else(|_| "nats://localhost:4222".to_string());
237
238
239
240
241
242

    // Create NatsQueue for event consumption
    let mut nats_queue = NatsQueue::new_with_consumer(
        stream_name.clone(),
        nats_server.clone(),
        std::time::Duration::from_secs(60), // 1 minute timeout
243
        consumer_id.clone(),
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    );
    nats_queue.connect_with_reset(router_reset_states).await?;

    // Always create NATS client (needed for both reset and snapshots)
    let client_options = dynamo_runtime::transports::nats::Client::builder()
        .server(&nats_server)
        .build()?;
    let nats_client = client_options.connect().await?;

    // Create bucket name for snapshots/state
    let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", component.subject()))
        .to_string()
        .replace("_", "-");

    // Handle initial state based on router_reset_states flag
259
260
261
262
    if !router_reset_states {
        // Try to download initial state from object store with stability check
        download_stable_snapshot(&nats_client, &bucket_name, &kv_events_tx).await?;
    } else {
263
264
265
266
267
268
269
        // Delete the bucket to reset state
        tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
        if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await {
            tracing::warn!("Failed to delete bucket (may not exist): {e:?}");
        }
    }

270
    // Cleanup orphaned consumers on startup
271
    cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await;
272

273
274
    // Get the generate endpoint and watch for instance deletions
    let generate_endpoint = component.endpoint("generate");
275
    let discovery_client = component.drt().discovery();
276
    let generate_discovery_key = DiscoveryQuery::Endpoint {
277
278
279
280
281
        namespace: component.namespace().name().to_string(),
        component: component.name().to_string(),
        endpoint: "generate".to_string(),
    };
    let mut instance_event_stream = discovery_client
282
283
284
285
286
287
288
        .list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
        .await?;

    // Watch for router deletions to clean up orphaned consumers via discovery
    let router_discovery_key = router_discovery_query(component.namespace().name());
    let mut router_event_stream = discovery_client
        .list_and_watch(router_discovery_key, Some(cancellation_token.clone()))
289
        .await?;
290

291
292
    // Get instances_rx for tracking current workers
    let client = generate_endpoint.client().await?;
293
    let instances_rx = client.instance_source.as_ref().clone();
294
295
296
297
298
299
300

    // Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
    let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
        maybe_get_workers_tx,
        maybe_snapshot_tx,
        router_snapshot_threshold,
    ) {
301
302
303
        Some(SnapshotResources {
            nats_client,
            bucket_name,
304
305
306
            instances_rx,
            get_workers_tx,
            snapshot_tx,
307
308
309
310
311
        })
    } else {
        None
    };

312
    tokio::spawn(async move {
313
314
315
316
317
318
319
        // Create interval with jitter
        let jitter_ms =
            rand::rng().random_range(-CHECK_INTERVAL_JITTER_MS..=CHECK_INTERVAL_JITTER_MS);
        let interval_duration = Duration::from_millis(
            (CHECK_INTERVAL_BASE.as_millis() as i64 + jitter_ms).max(1) as u64,
        );
        let mut check_interval = tokio::time::interval(interval_duration);
320
321
322
323
324
325
326
327
328
        check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            tokio::select! {
                biased;

                _ = cancellation_token.cancelled() => {
                    tracing::debug!("KV Router background task received cancellation signal");
                    // Clean up the queue and remove the durable consumer
329
                    // TODO: durable consumer cannot cleanup if ungraceful shutdown (crash)
330
331
332
333
334
335
                    if let Err(e) = nats_queue.shutdown(None).await {
                        tracing::warn!("Failed to shutdown NatsQueue: {e}");
                    }
                    break;
                }

336
                // Handle generate endpoint instance deletion events
337
338
                Some(discovery_event_result) = instance_event_stream.next() => {
                    let Ok(discovery_event) = discovery_event_result else {
339
340
341
                        continue;
                    };

342
                    let DiscoveryEvent::Removed(worker_id) = discovery_event else {
343
344
345
                        continue;
                    };

346
347
348
349
                    tracing::warn!(
                        worker_id = worker_id,
                        "DISCOVERY: Generate endpoint instance removed, removing worker"
                    );
350
351

                    if let Err(e) = remove_worker_tx.send(worker_id).await {
352
                        tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
353
354
355
                    }
                }

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
                // Handle event consumption
                result = nats_queue.dequeue_task(None) => {
                    match result {
                        Ok(Some(bytes)) => {
                            let event: RouterEvent = match serde_json::from_slice(&bytes) {
                                Ok(event) => event,
                                Err(e) => {
                                    tracing::warn!("Failed to deserialize RouterEvent: {e:?}");
                                    continue;
                                }
                            };

                            // Forward the RouterEvent to the indexer
                            if let Err(e) = kv_events_tx.send(event).await {
                                tracing::warn!(
                                    "failed to send kv event to indexer; shutting down: {e:?}"
                                );
                                break;
                            }
                        },
                        Ok(None) => {
                            tracing::trace!("Dequeue timeout, continuing");
                        },
                        Err(e) => {
                            tracing::error!("Failed to dequeue task: {e:?}");
                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        }
                    }
                }

386
                // Handle periodic stream checking and purging (only if snapshot_resources is provided)
387
                _ = check_interval.tick() => {
388
                    let Some(resources) = snapshot_resources.as_ref() else {
389
390
391
392
393
394
395
396
397
398
                        continue;
                    };

                    // Check total messages in the stream
                    let Ok(message_count) = nats_queue.get_stream_messages().await else {
                        tracing::warn!("Failed to get stream message count");
                        continue;
                    };

                    let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64;
399

400
401
402
403
                    if message_count <= threshold {
                        continue;
                    }

404
                    tracing::info!("Stream has {message_count} messages (threshold: {threshold}), performing purge and snapshot");
405

406
                    match resources.purge_then_snapshot(
407
                        &mut nats_queue,
408
                        &remove_worker_tx,
409
410
                    ).await {
                        Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
411
                        Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"),
412
413
414
                    }
                }

415
416
417
                // Handle router deletion events via discovery
                Some(router_event_result) = router_event_stream.next() => {
                    let Ok(router_event) = router_event_result else {
418
419
420
                        continue;
                    };

421
422
                    let DiscoveryEvent::Removed(router_instance_id) = router_event else {
                        // We only care about removals for cleaning up consumers
423
424
425
                        continue;
                    };

426
427
                    // The consumer UUID is the instance_id in hex format
                    let consumer_to_delete = router_instance_id.to_string();
428

429
430
431
432
                    tracing::info!(
                        router_instance_id = router_instance_id,
                        "DISCOVERY: Router instance removed, attempting to delete orphaned consumer: {consumer_to_delete}"
                    );
433

434
435
436
                    // Delete the consumer (allow race condition if multiple routers try to delete)
                    if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
                        tracing::warn!("Failed to delete consumer {consumer_to_delete}: {e}");
437
                    } else {
438
                        tracing::info!("Successfully deleted orphaned consumer: {consumer_to_delete}");
439
440
441
442
443
444
445
446
447
448
449
450
451
452
                    }
                }
            }
        }

        // Clean up the queue and remove the durable consumer
        if let Err(e) = nats_queue.shutdown(None).await {
            tracing::warn!("Failed to shutdown NatsQueue: {e}");
        }
    });

    Ok(())
}

453
/// Cleanup orphaned NATS consumers that no longer have corresponding router entries
454
455
456
async fn cleanup_orphaned_consumers(
    nats_queue: &mut NatsQueue,
    component: &Component,
457
    consumer_id: &str,
458
459
460
461
462
) {
    let Ok(consumers) = nats_queue.list_consumers().await else {
        return;
    };

463
464
465
466
467
468
469
    // Get active routers from discovery
    let discovery = component.drt().discovery();
    let Ok(router_instances) = discovery
        .list(router_discovery_query(component.namespace().name()))
        .await
    else {
        tracing::debug!("Failed to list router instances from discovery, skipping cleanup");
470
471
472
        return;
    };

473
474
    // Build set of active router instance IDs
    let active_instance_ids: HashSet<String> = router_instances
475
        .iter()
476
        .map(|instance| instance.instance_id().to_string())
477
478
479
        .collect();

    for consumer in consumers {
480
        if consumer == consumer_id {
481
482
483
            // Never delete myself (extra/redundant safeguard)
            continue;
        }
484
        if !active_instance_ids.contains(&consumer) {
485
            tracing::info!("Cleaning up orphaned consumer: {consumer}");
486
487
488
489
            let _ = nats_queue.shutdown(Some(consumer)).await;
        }
    }
}