registry.rs 12.7 KB
Newer Older
1
2
3
4
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
5
6
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
7
8
9

use anyhow::{Result, bail};
use dashmap::DashMap;
10
use dashmap::mapref::one::Ref;
11
use tokio::sync::watch;
12
13
use tokio_util::sync::CancellationToken;

14
use crate::protocols::WorkerId;
15

16
use super::indexer::{Indexer, create_indexer};
17
18
use super::listener::run_zmq_listener;

19
20
21
22
23
24
25
26
27
/// Lookup key for an indexer: one indexer exists per (model, tenant) pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IndexerKey {
    /// Name of the model the indexer belongs to.
    pub model_name: String,
    /// Tenant the indexer belongs to.
    pub tenant_id: String,
}

pub struct IndexerEntry {
    pub indexer: Indexer,
    pub block_size: u32,
28
29
30
}

pub struct WorkerEntry {
31
    pub endpoints: HashMap<u32, String>,
32
    pub replay_endpoints: HashMap<u32, String>,
33
    cancels: HashMap<u32, CancellationToken>,
34
35
}

36
37
38
39
40
41
42
43
44
/// State needed to restart a paused ZMQ listener.
struct ListenerState {
    endpoint: String,
    replay_endpoint: Option<String>,
    block_size: u32,
    indexer: Indexer,
    watermark: Arc<AtomicU64>,
}

45
46
pub struct WorkerRegistry {
    workers: DashMap<WorkerId, WorkerEntry>,
47
    indexers: DashMap<IndexerKey, IndexerEntry>,
48
    peers: DashMap<String, ()>,
49
50
51
52
    /// Persists across unregister/register cycles so gap detection works after re-registration.
    watermarks: DashMap<(WorkerId, u32), Arc<AtomicU64>>,
    /// Saved listener state for pause/resume. Populated on register, kept on pause.
    listener_states: DashMap<(WorkerId, u32), ListenerState>,
53
    num_threads: usize,
54
55
    ready_tx: watch::Sender<bool>,
    ready_rx: watch::Receiver<bool>,
56
57
58
}

impl WorkerRegistry {
59
    pub fn new(num_threads: usize) -> Self {
60
        let (ready_tx, ready_rx) = watch::channel(false);
61
62
        Self {
            workers: DashMap::new(),
63
            indexers: DashMap::new(),
64
            peers: DashMap::new(),
65
66
            watermarks: DashMap::new(),
            listener_states: DashMap::new(),
67
            num_threads,
68
69
            ready_tx,
            ready_rx,
70
71
72
        }
    }

73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    pub fn signal_ready(&self) {
        let _ = self.ready_tx.send(true);
    }

    pub fn ready_rx(&self) -> watch::Receiver<bool> {
        self.ready_rx.clone()
    }

    pub fn register_peer(&self, url: String) {
        self.peers.entry(url).or_insert(());
    }

    pub fn deregister_peer(&self, url: &str) -> bool {
        self.peers.remove(url).is_some()
    }

    pub fn list_peers(&self) -> Vec<String> {
        self.peers.iter().map(|entry| entry.key().clone()).collect()
    }

93
94
    #[expect(clippy::too_many_arguments)]
    pub async fn register(
95
96
97
98
99
100
101
        &self,
        instance_id: WorkerId,
        endpoint: String,
        dp_rank: u32,
        model_name: String,
        tenant_id: String,
        block_size: u32,
102
        replay_endpoint: Option<String>,
103
104
105
106
107
108
109
110
111
112
113
114
115
    ) -> Result<()> {
        let key = IndexerKey {
            model_name,
            tenant_id,
        };

        let indexer_entry = self.indexers.entry(key.clone()).or_insert_with(|| {
            tracing::info!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                block_size,
                "Creating new indexer"
            );
116
            super::metrics::inc_models();
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
            IndexerEntry {
                indexer: create_indexer(block_size, self.num_threads),
                block_size,
            }
        });

        if indexer_entry.block_size != block_size {
            bail!(
                "block_size mismatch for model={} tenant={}: existing={}, requested={}",
                key.model_name,
                key.tenant_id,
                indexer_entry.block_size,
                block_size
            );
        }

        let indexer = indexer_entry.indexer.clone();
        let bs = indexer_entry.block_size;
        drop(indexer_entry);

137
138
        // Check for duplicate and insert replay endpoint while holding the lock briefly.
        {
139
140
141
            let mut entry = self.workers.entry(instance_id).or_insert_with(|| {
                super::metrics::inc_workers();
                WorkerEntry {
142
143
144
                    endpoints: HashMap::new(),
                    replay_endpoints: HashMap::new(),
                    cancels: HashMap::new(),
145
146
                }
            });
147
148
149
150
151
152
153
154

            if entry.endpoints.contains_key(&dp_rank) {
                bail!("instance {instance_id} dp_rank {dp_rank} already registered");
            }

            if let Some(rep) = &replay_endpoint {
                entry.replay_endpoints.insert(dp_rank, rep.clone());
            }
155
156
        }

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
        // Reuse watermark if it survived a previous unregister (preserves gap detection).
        let watermark = self
            .watermarks
            .entry((instance_id, dp_rank))
            .or_insert_with(|| Arc::new(AtomicU64::new(u64::MAX)))
            .clone();

        self.listener_states.insert(
            (instance_id, dp_rank),
            ListenerState {
                endpoint: endpoint.clone(),
                replay_endpoint: replay_endpoint.clone(),
                block_size: bs,
                indexer: indexer.clone(),
                watermark: watermark.clone(),
            },
        );

175
176
        let cancel = CancellationToken::new();
        let child_cancel = cancel.child_token();
177
        let addr = endpoint.clone();
178
        let ready = self.ready_rx();
179

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
        // Connect the ZMQ socket and spawn the listener task (non-blocking).
        run_zmq_listener(
            instance_id,
            dp_rank,
            addr,
            bs,
            indexer,
            child_cancel,
            ready,
            replay_endpoint,
            watermark,
        )
        .await;

        // Re-acquire to store the endpoint and cancel token.
        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .expect("worker entry disappeared during listener setup");
199
200
        entry.endpoints.insert(dp_rank, endpoint);
        entry.cancels.insert(dp_rank, cancel);
201
202
203
        Ok(())
    }

204
205
206
207
208
209
    pub async fn deregister(
        &self,
        instance_id: WorkerId,
        model_name: &str,
        tenant_id: &str,
    ) -> Result<()> {
210
211
212
213
214
        let (_, entry) = self
            .workers
            .remove(&instance_id)
            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

215
216
        super::metrics::dec_workers();

217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
        for cancel in entry.cancels.values() {
            cancel.cancel();
        }

        let key = IndexerKey {
            model_name: model_name.to_string(),
            tenant_id: tenant_id.to_string(),
        };
        if let Some(ie) = self.indexers.get(&key) {
            ie.indexer.remove_worker(instance_id).await;
        } else {
            tracing::warn!(
                instance_id,
                model_name,
                tenant_id,
                "indexer key not found on deregister; tree will not be cleaned up"
            );
        }

236
237
238
        Ok(())
    }

239
240
241
242
243
244
245
    pub async fn deregister_dp_rank(
        &self,
        instance_id: WorkerId,
        dp_rank: u32,
        model_name: &str,
        tenant_id: &str,
    ) -> Result<()> {
246
247
248
249
250
251
252
253
254
        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

        if entry.endpoints.remove(&dp_rank).is_none() {
            bail!("instance {instance_id} dp_rank {dp_rank} not found");
        }

255
256
257
258
        if let Some(cancel) = entry.cancels.remove(&dp_rank) {
            cancel.cancel();
        }

259
260
        if entry.endpoints.is_empty() {
            drop(entry);
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
            return self.deregister(instance_id, model_name, tenant_id).await;
        }
        drop(entry);

        let key = IndexerKey {
            model_name: model_name.to_string(),
            tenant_id: tenant_id.to_string(),
        };
        if let Some(ie) = self.indexers.get(&key) {
            ie.indexer.remove_worker_dp_rank(instance_id, dp_rank).await;
        } else {
            tracing::warn!(
                instance_id,
                dp_rank,
                model_name,
                tenant_id,
                "indexer key not found on deregister_dp_rank; tree will not be cleaned up"
            );
        }

        Ok(())
    }

    pub async fn deregister_all_tenants(
        &self,
        instance_id: WorkerId,
        model_name: &str,
    ) -> Result<()> {
        let (_, entry) = self
            .workers
            .remove(&instance_id)
            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

294
295
        super::metrics::dec_workers();

296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
        for cancel in entry.cancels.values() {
            cancel.cancel();
        }

        let mut found = false;
        for ie in self.indexers.iter() {
            if ie.key().model_name == model_name {
                ie.indexer.remove_worker(instance_id).await;
                found = true;
            }
        }
        if !found {
            tracing::warn!(
                instance_id,
                model_name,
                "no indexers found for model on deregister_all_tenants; tree will not be cleaned up"
            );
313
314
315
316
317
        }

        Ok(())
    }

318
    #[cfg_attr(not(feature = "test-endpoints"), allow(dead_code))]
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
    pub fn pause_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

        let cancel = entry.cancels.remove(&dp_rank).ok_or_else(|| {
            anyhow::anyhow!("instance {instance_id} dp_rank {dp_rank} not active")
        })?;
        cancel.cancel();

        tracing::info!(instance_id, dp_rank, "Paused ZMQ listener");
        Ok(())
    }

334
    #[cfg_attr(not(feature = "test-endpoints"), allow(dead_code))]
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    pub async fn resume_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
        {
            let entry = self
                .workers
                .get(&instance_id)
                .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

            if entry.cancels.contains_key(&dp_rank) {
                bail!("instance {instance_id} dp_rank {dp_rank} already running");
            }
        }

        let state = self
            .listener_states
            .get(&(instance_id, dp_rank))
            .ok_or_else(|| anyhow::anyhow!("no saved state for {instance_id} dp_rank {dp_rank}"))?;

        let cancel = CancellationToken::new();
        let child_cancel = cancel.child_token();
        let ready = self.ready_rx();
        let addr = state.endpoint.clone();
        let bs = state.block_size;
        let indexer = state.indexer.clone();
        let replay_ep = state.replay_endpoint.clone();
        let watermark = state.watermark.clone();
        drop(state);

        run_zmq_listener(
            instance_id,
            dp_rank,
            addr,
            bs,
            indexer,
            child_cancel,
            ready,
            replay_ep,
            watermark,
        )
        .await;

        let mut entry = self
            .workers
            .get_mut(&instance_id)
            .expect("worker entry disappeared during listener resume");
        entry.cancels.insert(dp_rank, cancel);
        Ok(())
    }

383
384
385
    pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
        self.workers
            .iter()
386
            .map(|entry| (*entry.key(), entry.value().endpoints.clone()))
387
388
389
            .collect()
    }

390
391
392
393
    pub fn get_indexer(&self, key: &IndexerKey) -> Option<Ref<'_, IndexerKey, IndexerEntry>> {
        self.indexers.get(key)
    }

394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    pub fn get_or_create_indexer(&self, key: IndexerKey, block_size: u32) -> Indexer {
        let entry = self.indexers.entry(key.clone()).or_insert_with(|| {
            tracing::info!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                block_size,
                "Creating indexer from recovery dump"
            );
            IndexerEntry {
                indexer: create_indexer(block_size, self.num_threads),
                block_size,
            }
        });
        if entry.block_size != block_size {
            tracing::warn!(
                model_name = %key.model_name,
                tenant_id = %key.tenant_id,
                existing_block_size = entry.block_size,
                requested_block_size = block_size,
                "Block size mismatch for existing indexer"
            );
        }
        entry.indexer.clone()
    }

    pub fn all_indexers_with_block_size(&self) -> Vec<(IndexerKey, Indexer, u32)> {
420
421
        self.indexers
            .iter()
422
423
424
425
426
427
428
            .map(|entry| {
                (
                    entry.key().clone(),
                    entry.value().indexer.clone(),
                    entry.value().block_size,
                )
            })
429
            .collect()
430
431
    }
}