// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Utility functions for working with discovery streams

use serde::Deserialize;

use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryStream};

/// Collapse state keyed by full `DiscoveryInstanceId` into a flat `HashMap<u64, V>`.
///
/// Several keys can share the same raw `instance_id` (e.g., base model +
/// LoRA adapters on the same worker, or the same worker on different
/// endpoints). Exactly one value is kept per `instance_id`:
///
/// - An entry without a model suffix (a base model, or any key that is not a
///   `Model` key) replaces whatever was chosen before it.
/// - An entry with a model suffix (LoRA) is used only if nothing has been
///   chosen for that `instance_id` yet.
///
/// Ties (several base entries, or several LoRA entries with no base) are
/// broken by `HashMap` iteration order, i.e. arbitrarily.
///
/// Winners are selected by reference first, so only the single surviving
/// value per `instance_id` is cloned.
fn collapse_by_instance_id<V: Clone>(
    state: &std::collections::HashMap<DiscoveryInstanceId, V>,
) -> std::collections::HashMap<u64, V> {
    use std::collections::hash_map::Entry;

    let mut winners: std::collections::HashMap<u64, &V> =
        std::collections::HashMap::with_capacity(state.len());

    for (id, val) in state {
        // Only `Model` keys carry a suffix; every other key kind counts as "base".
        let is_base = match id {
            DiscoveryInstanceId::Model(mid) => mid.model_suffix.is_none(),
            _ => true,
        };

        match winners.entry(id.instance_id()) {
            Entry::Vacant(slot) => {
                slot.insert(val);
            }
            Entry::Occupied(mut slot) => {
                // A base entry always takes over; a LoRA never displaces
                // an already-chosen entry.
                if is_base {
                    slot.insert(val);
                }
            }
        }
    }

    winners
        .into_iter()
        .map(|(instance_id, val)| (instance_id, val.clone()))
        .collect()
}

/// Helper to watch a discovery stream and extract a specific field into a HashMap
///
/// This helper spawns a background task that:
/// - Deserializes ModelCards from discovery events
/// - Extracts a specific field using the provided extractor function
/// - Maintains a HashMap<instance_id, Field> that auto-updates on Add/Remove events
/// - Returns a watch::Receiver that consumers can use to read the current state
///
/// # Type Parameters
/// - `T`: The type to deserialize from DiscoveryInstance (e.g., ModelDeploymentCard)
/// - `V`: The extracted field type (e.g., ModelRuntimeConfig)
/// - `F`: The extractor function type
///
/// # Arguments
/// - `stream`: The discovery event stream to watch
/// - `extractor`: Function that extracts the desired field from the deserialized type
///
/// # Example
/// ```ignore
51
/// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/// let runtime_configs_rx = watch_and_extract_field(
///     stream,
///     |card: ModelDeploymentCard| card.runtime_config,
/// );
///
/// // Use it:
/// let configs = runtime_configs_rx.borrow();
/// if let Some(config) = configs.get(&worker_id) {
///     // Use config...
/// }
/// ```
pub fn watch_and_extract_field<T, V, F>(
    stream: DiscoveryStream,
    extractor: F,
) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
where
    T: for<'de> Deserialize<'de> + 'static,
69
    V: Clone + PartialEq + Send + Sync + 'static,
70
71
72
73
74
75
76
77
    F: Fn(T) -> V + Send + 'static,
{
    use futures::StreamExt;
    use std::collections::HashMap;

    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());

    tokio::spawn(async move {
78
79
80
81
82
83
        // Internal state keyed by full DiscoveryInstanceId to correctly
        // distinguish entries across namespaces, components, endpoints, and
        // model suffixes — even when they share the same raw instance_id.
        // Collapsed to HashMap<u64, V> for consumers, preferring suffix=None
        // (base model) when multiple entries exist for the same instance_id.
        let mut state: HashMap<DiscoveryInstanceId, V> = HashMap::new();
84
85
86
87
88
89
        let mut stream = stream;

        while let Some(result) = stream.next().await {
            match result {
                Ok(DiscoveryEvent::Added(instance)) => {
                    let instance_id = instance.instance_id();
90
                    let key = instance.id();
91
92

                    // Deserialize the full instance into type T
93
                    let deserialized: T = match instance.deserialize_model() {
94
95
96
97
98
99
100
101
102
103
104
105
106
107
                        Ok(d) => d,
                        Err(e) => {
                            tracing::warn!(
                                instance_id,
                                error = %e,
                                "Failed to deserialize discovery instance, skipping"
                            );
                            continue;
                        }
                    };

                    // Extract the field we care about
                    let value = extractor(deserialized);

108
109
110
111
112
113
114
115
116
117
118
119
120
121
                    tracing::debug!(
                        instance_id,
                        ?key,
                        state_len = state.len(),
                        "watch_and_extract_field: inserting instance"
                    );

                    state.insert(key, value);

                    // Only publish if the collapsed worker view actually changed,
                    // to avoid waking downstream watchers on no-op events
                    // (e.g., adding a LoRA when base model already represents the worker).
                    let collapsed = collapse_by_instance_id(&state);
                    if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
122
123
124
125
                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
                        break;
                    }
                }
126
                Ok(DiscoveryEvent::Removed(id)) => {
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
                    let had_entry = state.contains_key(&id);

                    tracing::debug!(
                        instance_id = id.instance_id(),
                        ?id,
                        had_entry,
                        state_len = state.len(),
                        "watch_and_extract_field: removing instance"
                    );

                    state.remove(&id);

                    // Only publish if the collapsed worker view actually changed,
                    // to avoid waking downstream watchers on no-op events
                    // (e.g., adding a LoRA when base model already represents the worker).
                    let collapsed = collapse_by_instance_id(&state);
                    if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
                        break;
                    }
                }
                Err(e) => {
                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
                    // Continue processing other events
                }
            }
        }

        tracing::debug!("watch_and_extract_field task stopped");
    });

    rx
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::mock::{MockDiscovery, SharedMockRegistry};
    use crate::discovery::{Discovery, DiscoveryQuery, DiscoverySpec};

    /// Collapsed view published by `watch_and_extract_field` in these tests.
    type View = std::collections::HashMap<u64, String>;

    /// Minimal card: only the field the extractor closures below read.
    #[derive(serde::Deserialize, Clone, Debug)]
    struct FakeCard {
        display_name: String,
    }

    /// Build a model spec in namespace "ns" / component "comp" on `endpoint`.
    fn spec_on(
        endpoint: &str,
        card_json: serde_json::Value,
        model_suffix: Option<String>,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: String::from("ns"),
            component: String::from("comp"),
            endpoint: String::from(endpoint),
            card_json,
            model_suffix,
        }
    }

    /// Base-model spec (no suffix) on the "generate" endpoint.
    fn model_spec(name: &str) -> DiscoverySpec {
        spec_on(
            "generate",
            serde_json::json!({ "display_name": name }),
            None,
        )
    }

    /// LoRA adapter spec on the "generate" endpoint; the adapter name doubles
    /// as the model suffix.
    fn lora_spec(lora_name: &str) -> DiscoverySpec {
        let card = serde_json::json!({
            "display_name": lora_name,
            "source_path": "base-model",
            "lora": { "name": lora_name },
        });
        spec_on("generate", card, Some(String::from(lora_name)))
    }

    /// Re-check `pred` against the receiver's current value every 10ms, up to
    /// 100 times; panic with `msg` and the last observed state if it never holds.
    async fn poll_until(
        rx: &tokio::sync::watch::Receiver<View>,
        pred: impl Fn(&View) -> bool,
        msg: &str,
    ) {
        let mut attempts_left = 100;
        while attempts_left > 0 {
            // Scope the borrow so the watch guard is released before sleeping.
            let satisfied = {
                let snapshot = rx.borrow();
                pred(&snapshot)
            };
            if satisfied {
                return;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            attempts_left -= 1;
        }
        panic!("{}: state={:?}", msg, *rx.borrow());
    }

    /// Dropping one LoRA adapter must leave the worker's entry in place: the
    /// base model and all adapters on a worker share one instance_id, and the
    /// collapsed view prefers the base model, falling back to a remaining
    /// adapter once the base model is gone.
    #[tokio::test]
    async fn test_lora_unregister_preserves_worker_runtime_config() {
        // One discovery handle => every registration carries instance_id 42.
        let discovery = MockDiscovery::new(Some(42), SharedMockRegistry::new());

        let stream = discovery
            .list_and_watch(
                DiscoveryQuery::EndpointModels {
                    namespace: String::from("ns"),
                    component: String::from("comp"),
                    endpoint: String::from("generate"),
                },
                None,
            )
            .await
            .unwrap();

        // display_name stands in for the runtime config a real caller would extract.
        let rx = watch_and_extract_field(stream, |card: FakeCard| card.display_name);

        // Base model plus two adapters, all on worker 42.
        let base_handle = discovery.register(model_spec("base-model")).await.unwrap();
        let lora_a_handle = discovery.register(lora_spec("lora-a")).await.unwrap();
        discovery.register(lora_spec("lora-b")).await.unwrap();

        poll_until(
            &rx,
            |view| view.contains_key(&42),
            "Worker 42 should be present after registrations",
        )
        .await;

        // Drop only lora-a; base model and lora-b are still registered.
        discovery.unregister(lora_a_handle).await.unwrap();

        // The collapsed view keeps reporting the base model.
        poll_until(
            &rx,
            |view| view.get(&42).map(String::as_str) == Some("base-model"),
            "Worker 42 should have base-model after removing lora-a",
        )
        .await;
        assert_eq!(
            rx.borrow().get(&42).map(String::as_str),
            Some("base-model")
        );

        // With the base model gone, the remaining adapter takes over.
        discovery.unregister(base_handle).await.unwrap();

        poll_until(
            &rx,
            |view| view.get(&42).map(String::as_str) == Some("lora-b"),
            "Worker 42 should fall back to lora-b after removing base model",
        )
        .await;
        assert_eq!(rx.borrow().get(&42).map(String::as_str), Some("lora-b"));
    }

    /// One worker (instance_id 7) registered on two endpoints and watched via
    /// AllModels: the two registrations must not alias, so removing one
    /// endpoint's registration leaves the worker visible through the other.
    #[tokio::test]
    async fn test_all_models_cross_endpoint_no_alias() {
        let registry = SharedMockRegistry::new();
        // A single handle => both registrations share instance_id 7.
        let discovery = MockDiscovery::new(Some(7), registry.clone());

        let stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, None)
            .await
            .unwrap();
        let rx = watch_and_extract_field(stream, |card: FakeCard| card.display_name);

        // First registration: endpoint "ep-a".
        let ep_a_spec = spec_on(
            "ep-a",
            serde_json::json!({ "display_name": "model-on-ep-a" }),
            None,
        );
        let ep_a_handle = discovery.register(ep_a_spec).await.unwrap();

        // Second registration: endpoint "ep-b".
        let ep_b_spec = spec_on(
            "ep-b",
            serde_json::json!({ "display_name": "model-on-ep-b" }),
            None,
        );
        discovery.register(ep_b_spec).await.unwrap();

        poll_until(
            &rx,
            |view| view.contains_key(&7),
            "Worker 7 should appear after registrations",
        )
        .await;

        // Take ep-a away; ep-b must keep worker 7 in the view.
        discovery.unregister(ep_a_handle).await.unwrap();

        poll_until(
            &rx,
            |view| view.get(&7).map(String::as_str) == Some("model-on-ep-b"),
            "Worker 7 should still be present via ep-b after removing ep-a",
        )
        .await;
    }
}