registry.rs 18.1 KB
Newer Older
1
2
3
4
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
5
6
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
7
8
9

use anyhow::{Result, bail};
use dashmap::DashMap;
10
use dashmap::mapref::one::Ref;
11
use tokio::sync::watch;
12
13
use tokio_util::sync::CancellationToken;

14
use crate::protocols::WorkerId;
15

16
use super::indexer::{Indexer, create_indexer};
17
18
use super::listener::run_zmq_listener;

19
20
21
22
23
24
25
26
27
/// Identifies one indexer: the `(model_name, tenant_id)` pair.
///
/// Derives `Hash`/`Eq` so it can serve as a map key.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct IndexerKey {
    /// Model name component of the key.
    pub model_name: String,
    /// Tenant identifier component of the key.
    pub tenant_id: String,
}

pub struct IndexerEntry {
    pub indexer: Indexer,
    pub block_size: u32,
28
29
30
}

pub struct WorkerEntry {
31
    pub endpoints: HashMap<u32, String>,
32
    pub replay_endpoints: HashMap<u32, String>,
33
    cancels: HashMap<u32, CancellationToken>,
34
35
}

36
37
38
39
40
41
42
43
44
/// State needed to restart a paused ZMQ listener.
struct ListenerState {
    /// Endpoint the listener was started with.
    endpoint: String,
    /// Replay endpoint supplied at registration, if any.
    replay_endpoint: Option<String>,
    /// Block size of the indexer the listener feeds.
    block_size: u32,
    /// Indexer handle passed to the listener.
    indexer: Indexer,
    /// Watermark shared with the listener for this `(worker, dp_rank)`.
    watermark: Arc<AtomicU64>,
}

45
46
pub struct WorkerRegistry {
    workers: DashMap<WorkerId, WorkerEntry>,
47
    indexers: DashMap<IndexerKey, IndexerEntry>,
48
    peers: DashMap<String, ()>,
49
50
51
52
    /// Persists across unregister/register cycles so gap detection works after re-registration.
    watermarks: DashMap<(WorkerId, u32), Arc<AtomicU64>>,
    /// Saved listener state for pause/resume. Populated on register, kept on pause.
    listener_states: DashMap<(WorkerId, u32), ListenerState>,
53
54
55
    /// Workers added via MDC discovery (no ZMQ listener). Maps worker_id → indexer key.
    #[cfg(feature = "indexer-runtime")]
    discovered_workers: DashMap<WorkerId, IndexerKey>,
56
    num_threads: usize,
57
58
    ready_tx: watch::Sender<bool>,
    ready_rx: watch::Receiver<bool>,
59
60
61
}

impl WorkerRegistry {
62
    /// Creates an empty registry.
    ///
    /// `num_threads` is forwarded to `create_indexer` whenever a new indexer is
    /// created. The readiness channel starts out as `false`.
    pub fn new(num_threads: usize) -> Self {
        let (ready_tx, ready_rx) = watch::channel(false);
        Self {
            workers: DashMap::new(),
            indexers: DashMap::new(),
            peers: DashMap::new(),
            watermarks: DashMap::new(),
            listener_states: DashMap::new(),
            #[cfg(feature = "indexer-runtime")]
            discovered_workers: DashMap::new(),
            num_threads,
            ready_tx,
            ready_rx,
        }
    }

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    /// Publishes `true` on the readiness channel.
    ///
    /// The result of the send is intentionally discarded.
    pub fn signal_ready(&self) {
        let _send_result = self.ready_tx.send(true);
    }

    /// Returns a fresh receiver for the readiness channel.
    pub fn ready_rx(&self) -> watch::Receiver<bool> {
        watch::Receiver::clone(&self.ready_rx)
    }

    /// Adds `url` to the peer set; registering an already-known URL is a no-op.
    pub fn register_peer(&self, url: String) {
        self.peers.entry(url).or_default();
    }

    /// Removes `url` from the peer set, returning `true` if it was present.
    pub fn deregister_peer(&self, url: &str) -> bool {
        let removed = self.peers.remove(url);
        removed.is_some()
    }

    /// Returns a snapshot of all registered peer URLs (in map iteration order).
    pub fn list_peers(&self) -> Vec<String> {
        self.peers
            .iter()
            .map(|peer| peer.key().to_owned())
            .collect()
    }

98
99
    #[expect(clippy::too_many_arguments)]
    pub async fn register(
100
101
102
103
104
105
106
        &self,
        instance_id: WorkerId,
        endpoint: String,
        dp_rank: u32,
        model_name: String,
        tenant_id: String,
        block_size: u32,
107
        replay_endpoint: Option<String>,
108
    ) -> Result<()> {
109
110
111
112
113
114
115
116
117
        // Reject if this worker was already added via discovery
        #[cfg(feature = "indexer-runtime")]
        if self.discovered_workers.contains_key(&instance_id) {
            bail!(
                "instance {instance_id} is already registered via discovery; \
                 use the Dynamo runtime to manage it"
            );
        }

118
119
120
121
122
123
124
125
126
127
128
129
        let key = IndexerKey {
            model_name,
            tenant_id,
        };

        let indexer_entry = self.indexers.entry(key.clone()).or_insert_with(|| {
            tracing::info!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                block_size,
                "Creating new indexer"
            );
130
            super::metrics::inc_models();
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
            IndexerEntry {
                indexer: create_indexer(block_size, self.num_threads),
                block_size,
            }
        });

        if indexer_entry.block_size != block_size {
            bail!(
                "block_size mismatch for model={} tenant={}: existing={}, requested={}",
                key.model_name,
                key.tenant_id,
                indexer_entry.block_size,
                block_size
            );
        }

        let indexer = indexer_entry.indexer.clone();
        let bs = indexer_entry.block_size;
        drop(indexer_entry);

151
152
        // Check for duplicate and insert replay endpoint while holding the lock briefly.
        {
153
154
155
            let mut entry = self.workers.entry(instance_id).or_insert_with(|| {
                super::metrics::inc_workers();
                WorkerEntry {
156
157
158
                    endpoints: HashMap::new(),
                    replay_endpoints: HashMap::new(),
                    cancels: HashMap::new(),
159
160
                }
            });
161
162
163
164
165
166
167
168

            if entry.endpoints.contains_key(&dp_rank) {
                bail!("instance {instance_id} dp_rank {dp_rank} already registered");
            }

            if let Some(rep) = &replay_endpoint {
                entry.replay_endpoints.insert(dp_rank, rep.clone());
            }
169
170
        }

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
        // Reuse watermark if it survived a previous unregister (preserves gap detection).
        let watermark = self
            .watermarks
            .entry((instance_id, dp_rank))
            .or_insert_with(|| Arc::new(AtomicU64::new(u64::MAX)))
            .clone();

        self.listener_states.insert(
            (instance_id, dp_rank),
            ListenerState {
                endpoint: endpoint.clone(),
                replay_endpoint: replay_endpoint.clone(),
                block_size: bs,
                indexer: indexer.clone(),
                watermark: watermark.clone(),
            },
        );

189
190
        let cancel = CancellationToken::new();
        let child_cancel = cancel.child_token();
191
        let addr = endpoint.clone();
192
        let ready = self.ready_rx();
193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        // Connect the ZMQ socket and spawn the listener task (non-blocking).
        run_zmq_listener(
            instance_id,
            dp_rank,
            addr,
            bs,
            indexer,
            child_cancel,
            ready,
            replay_endpoint,
            watermark,
        )
        .await;

        // Re-acquire to store the endpoint and cancel token.
        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .expect("worker entry disappeared during listener setup");
213
214
        entry.endpoints.insert(dp_rank, endpoint);
        entry.cancels.insert(dp_rank, cancel);
215
216
217
        Ok(())
    }

218
219
220
221
222
223
224
225
226
227
    /// Removes a worker (all of its `dp_rank`s) and purges it from the indexer
    /// identified by `(model_name, tenant_id)`.
    ///
    /// ZMQ-registered workers have every listener cancelled; discovered workers
    /// are simply dropped from the discovery map. Watermarks are intentionally
    /// kept so gap detection works after re-registration.
    ///
    /// # Errors
    /// Returns an error if `instance_id` is neither ZMQ-registered nor discovered.
    pub async fn deregister(
        &self,
        instance_id: WorkerId,
        model_name: &str,
        tenant_id: &str,
    ) -> Result<()> {
        let key = IndexerKey {
            model_name: model_name.to_string(),
            tenant_id: tenant_id.to_string(),
        };

        // Check ZMQ-registered workers first, then discovery workers (if runtime mode)
        if let Some((_, entry)) = self.workers.remove(&instance_id) {
            super::metrics::dec_workers();
            for cancel in entry.cancels.values() {
                cancel.cancel();
            }
        } else if self.remove_discovered_worker(instance_id) {
            super::metrics::dec_workers();
            tracing::info!(instance_id, "Deregistering discovered worker via HTTP");
        } else {
            bail!("instance {instance_id} not found");
        }

        // Saved listener state is only meaningful while the worker is registered;
        // drop it so it cannot be resumed later and does not keep the indexer alive.
        self.listener_states
            .retain(|(id, _), _| *id != instance_id);

        // Clone the handle out so no map guard is held across the `.await` below.
        let indexer = self.indexers.get(&key).map(|ie| ie.indexer.clone());
        if let Some(indexer) = indexer {
            indexer.remove_worker(instance_id).await;
        } else {
            tracing::warn!(
                instance_id,
                model_name,
                tenant_id,
                "indexer key not found on deregister; tree will not be cleaned up"
            );
        }

        Ok(())
    }

256
257
258
259
260
261
262
    /// Removes a single `dp_rank` of a ZMQ-registered worker.
    ///
    /// The rank's endpoint, replay endpoint, cancellation token and saved listener
    /// state are dropped and its listener is cancelled. If this was the worker's
    /// last rank, the whole worker is removed via [`Self::deregister`]; otherwise
    /// only the rank is purged from the indexer for `(model_name, tenant_id)`.
    ///
    /// # Errors
    /// Returns an error if the worker or the `dp_rank` is not registered.
    pub async fn deregister_dp_rank(
        &self,
        instance_id: WorkerId,
        dp_rank: u32,
        model_name: &str,
        tenant_id: &str,
    ) -> Result<()> {
        // Mutate the worker entry inside a scope so the map guard is released
        // before any `.await`.
        let was_last_rank = {
            let mut entry = self
                .workers
                .get_mut(&instance_id)
                .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

            if entry.endpoints.remove(&dp_rank).is_none() {
                bail!("instance {instance_id} dp_rank {dp_rank} not found");
            }

            if let Some(cancel) = entry.cancels.remove(&dp_rank) {
                cancel.cancel();
            }

            // Do not leave a stale replay endpoint behind for a later re-registration.
            entry.replay_endpoints.remove(&dp_rank);

            let now_empty = entry.endpoints.is_empty();
            now_empty
        };

        // Without this, `resume_listener` could restart a listener for a rank
        // that is no longer registered.
        self.listener_states.remove(&(instance_id, dp_rank));

        if was_last_rank {
            return self.deregister(instance_id, model_name, tenant_id).await;
        }

        let key = IndexerKey {
            model_name: model_name.to_string(),
            tenant_id: tenant_id.to_string(),
        };
        // Clone the handle out so no map guard is held across the `.await` below.
        let indexer = self.indexers.get(&key).map(|ie| ie.indexer.clone());
        if let Some(indexer) = indexer {
            indexer.remove_worker_dp_rank(instance_id, dp_rank).await;
        } else {
            tracing::warn!(
                instance_id,
                dp_rank,
                model_name,
                tenant_id,
                "indexer key not found on deregister_dp_rank; tree will not be cleaned up"
            );
        }

        Ok(())
    }

    pub async fn deregister_all_tenants(
        &self,
        instance_id: WorkerId,
        model_name: &str,
    ) -> Result<()> {
306
307
308
309
310
311
312
313
314
315
316
317
318
319
        // Check ZMQ-registered workers first, then discovery workers (if runtime mode)
        if let Some((_, entry)) = self.workers.remove(&instance_id) {
            super::metrics::dec_workers();
            for cancel in entry.cancels.values() {
                cancel.cancel();
            }
        } else if self.remove_discovered_worker(instance_id) {
            super::metrics::dec_workers();
            tracing::info!(
                instance_id,
                "Deregistering discovered worker (all tenants) via HTTP"
            );
        } else {
            bail!("instance {instance_id} not found");
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
        }

        let mut found = false;
        for ie in self.indexers.iter() {
            if ie.key().model_name == model_name {
                ie.indexer.remove_worker(instance_id).await;
                found = true;
            }
        }
        if !found {
            tracing::warn!(
                instance_id,
                model_name,
                "no indexers found for model on deregister_all_tenants; tree will not be cleaned up"
            );
335
336
337
338
339
        }

        Ok(())
    }

340
    #[cfg_attr(not(feature = "test-endpoints"), allow(dead_code))]
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
    pub fn pause_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

        let cancel = entry.cancels.remove(&dp_rank).ok_or_else(|| {
            anyhow::anyhow!("instance {instance_id} dp_rank {dp_rank} not active")
        })?;
        cancel.cancel();

        tracing::info!(instance_id, dp_rank, "Paused ZMQ listener");
        Ok(())
    }

356
    #[cfg_attr(not(feature = "test-endpoints"), allow(dead_code))]
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
    pub async fn resume_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
        {
            let entry = self
                .workers
                .get(&instance_id)
                .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

            if entry.cancels.contains_key(&dp_rank) {
                bail!("instance {instance_id} dp_rank {dp_rank} already running");
            }
        }

        let state = self
            .listener_states
            .get(&(instance_id, dp_rank))
            .ok_or_else(|| anyhow::anyhow!("no saved state for {instance_id} dp_rank {dp_rank}"))?;

        let cancel = CancellationToken::new();
        let child_cancel = cancel.child_token();
        let ready = self.ready_rx();
        let addr = state.endpoint.clone();
        let bs = state.block_size;
        let indexer = state.indexer.clone();
        let replay_ep = state.replay_endpoint.clone();
        let watermark = state.watermark.clone();
        drop(state);

        run_zmq_listener(
            instance_id,
            dp_rank,
            addr,
            bs,
            indexer,
            child_cancel,
            ready,
            replay_ep,
            watermark,
        )
        .await;

        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .expect("worker entry disappeared during listener resume");
        entry.cancels.insert(dp_rank, cancel);
        Ok(())
    }

405
    pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
406
407
408
        #[allow(unused_mut)]
        let mut result: Vec<(WorkerId, HashMap<u32, String>)> = self
            .workers
409
            .iter()
410
            .map(|entry| (*entry.key(), entry.value().endpoints.clone()))
411
412
413
414
415
416
417
418
419
420
421
422
423
424
            .collect();

        // Include discovered workers (no ZMQ endpoints)
        #[cfg(feature = "indexer-runtime")]
        for entry in self.discovered_workers.iter() {
            let worker_id = *entry.key();
            // Skip if already in the workers map (shouldn't happen, but be safe)
            if self.workers.contains_key(&worker_id) {
                continue;
            }
            result.push((worker_id, HashMap::new()));
        }

        result
425
426
    }

427
428
429
430
    /// Looks up the indexer entry for `key`.
    ///
    /// The returned guard borrows from the underlying map; keep it short-lived.
    pub fn get_indexer(&self, key: &IndexerKey) -> Option<Ref<'_, IndexerKey, IndexerEntry>> {
        let indexers = &self.indexers;
        indexers.get(key)
    }

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    /// Returns the indexer for `key`, creating it with `block_size` if absent.
    ///
    /// When an indexer already exists with a different block size, a warning is
    /// logged and the existing indexer is returned unchanged.
    ///
    // NOTE(review): unlike `register`, creation here does not call
    // `metrics::inc_models()` — confirm whether that is intentional.
    pub fn get_or_create_indexer(&self, key: IndexerKey, block_size: u32) -> Indexer {
        let make_entry = || {
            tracing::info!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                block_size,
                "Creating indexer from recovery dump"
            );
            IndexerEntry {
                indexer: create_indexer(block_size, self.num_threads),
                block_size,
            }
        };
        let slot = self.indexers.entry(key.clone()).or_insert_with(make_entry);

        let existing = slot.block_size;
        if existing != block_size {
            tracing::warn!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                existing_block_size = existing,
                requested_block_size = block_size,
                "Block size mismatch for existing indexer"
            );
        }

        slot.indexer.clone()
    }

    /// Returns a snapshot of every indexer as `(key, indexer handle, block_size)`.
    pub fn all_indexers_with_block_size(&self) -> Vec<(IndexerKey, Indexer, u32)> {
        self.indexers
            .iter()
            .map(|entry| {
                (
                    entry.key().clone(),
                    entry.value().indexer.clone(),
                    entry.value().block_size,
                )
            })
            .collect()
    }
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568

    /// Removes `_instance_id` from the discovered-workers map, if that map exists.
    ///
    /// Returns `true` when an entry was removed. Without the `indexer-runtime`
    /// feature there is no such map and the result is always `false`.
    fn remove_discovered_worker(&self, _instance_id: WorkerId) -> bool {
        #[cfg(feature = "indexer-runtime")]
        let removed = self.discovered_workers.remove(&_instance_id).is_some();
        #[cfg(not(feature = "indexer-runtime"))]
        let removed = false;
        removed
    }

    // ---------------------------------------------------------------
    // Discovery-based worker management (no ZMQ listener)
    // ---------------------------------------------------------------

    /// Records a worker found through MDC discovery.
    ///
    /// The indexer for `(model_name, tenant_id)` is created on demand, but no ZMQ
    /// listener is started — events arrive via the event plane.
    ///
    /// # Errors
    /// Fails if the worker is already ZMQ-registered, or if `block_size` differs
    /// from the block size of the existing indexer for this key.
    ///
    // NOTE(review): `deregister` calls `metrics::dec_workers()` for discovered
    // workers, but nothing here calls `inc_workers()` — confirm the gauge stays balanced.
    #[cfg(feature = "indexer-runtime")]
    pub fn add_worker_from_discovery(
        &self,
        instance_id: WorkerId,
        model_name: String,
        tenant_id: String,
        block_size: u32,
    ) -> Result<()> {
        // Reject if this worker is already registered via ZMQ (--workers or /register)
        let already_zmq = self.workers.contains_key(&instance_id);
        if already_zmq {
            bail!(
                "instance {instance_id} is already registered via ZMQ; \
                 cannot add via discovery"
            );
        }

        let key = IndexerKey {
            model_name,
            tenant_id,
        };

        // Scope the map guard so it is released before touching `discovered_workers`.
        {
            let slot = self.indexers.entry(key.clone()).or_insert_with(|| {
                tracing::info!(
                    model_name = %key.model_name,
                    tenant_id = %key.tenant_id,
                    block_size,
                    "Creating new indexer (discovery)"
                );
                IndexerEntry {
                    indexer: create_indexer(block_size, self.num_threads),
                    block_size,
                }
            });

            let existing = slot.block_size;
            if existing != block_size {
                bail!(
                    "block_size mismatch for model={} tenant={}: existing={}, requested={}",
                    key.model_name,
                    key.tenant_id,
                    existing,
                    block_size
                );
            }
        }

        self.discovered_workers.insert(instance_id, key);
        Ok(())
    }

    /// Remove a worker that was discovered via MDC.
    ///
    /// The worker is dropped from the discovery map and, if its indexer still
    /// exists, purged from that indexer. Unknown workers are logged at debug
    /// level and otherwise ignored.
    #[cfg(feature = "indexer-runtime")]
    pub async fn remove_worker_from_discovery(&self, instance_id: WorkerId) {
        let Some((_, key)) = self.discovered_workers.remove(&instance_id) else {
            tracing::debug!(
                instance_id,
                "remove_worker_from_discovery: worker not in discovered_workers map"
            );
            return;
        };

        // Clone the handle out so no map guard is held across the `.await` below.
        let indexer = self.indexers.get(&key).map(|ie| ie.indexer.clone());
        if let Some(indexer) = indexer {
            indexer.remove_worker(instance_id).await;
        }
    }

    /// Resolves the indexer that handles events for `worker_id`.
    ///
    /// Discovery-registered workers map to the indexer stored under their key.
    /// ZMQ-registered workers (legacy `--workers` mode) fall back to whichever
    /// indexer the map iterator yields first. Returns `None` for unknown workers.
    #[cfg(feature = "indexer-runtime")]
    pub fn get_indexer_for_worker(&self, worker_id: WorkerId) -> Option<Indexer> {
        // Check discovery workers first (more common in runtime mode)
        let via_discovery = self.discovered_workers.get(&worker_id).and_then(|key| {
            self.indexers
                .get(key.value())
                .map(|ie| ie.indexer.clone())
        });
        if via_discovery.is_some() {
            return via_discovery;
        }

        // Legacy fallback applies only to workers present in the ZMQ-registered map.
        if !self.workers.contains_key(&worker_id) {
            return None;
        }
        let first = self.indexers.iter().next().map(|ie| ie.indexer.clone());
        first
    }
569
}