"examples/backends/vllm/vscode:/vscode.git/clone" did not exist on "8ed69ea2f8f73a00512bfe15045e7803bb9b63cb"
mock.rs 15.3 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
// SPDX-License-Identifier: Apache-2.0

use super::{
5
6
    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
    DiscoverySpec, DiscoveryStream,
7
};
8
use anyhow::Result;
9
10
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
11
use tokio_util::sync::CancellationToken;
12
13
14
15
16
17
18
19
20
21
22
23
24

/// In-memory instance list shared between mock discovery clients.
///
/// Clones share the same underlying list (it lives behind an `Arc`), so an
/// instance added through one handle is visible through every other clone.
#[derive(Clone, Default)]
pub struct SharedMockRegistry {
    /// Currently registered instances, guarded by a mutex.
    instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
}

impl SharedMockRegistry {
    /// Creates a registry with no registered instances.
    pub fn new() -> Self {
        Self {
            instances: Arc::new(Mutex::new(Vec::new())),
        }
    }
}

25
26
27
/// Mock implementation of Discovery for testing
/// We can potentially remove this once we have KVStoreDiscovery fully tested
pub struct MockDiscovery {
28
29
30
31
    instance_id: u64,
    registry: SharedMockRegistry,
}

32
impl MockDiscovery {
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
        let instance_id = instance_id.unwrap_or_else(|| {
            use std::sync::atomic::{AtomicU64, Ordering};
            static COUNTER: AtomicU64 = AtomicU64::new(1);
            COUNTER.fetch_add(1, Ordering::SeqCst)
        });

        Self {
            instance_id,
            registry,
        }
    }
}

47
48
49
/// Helper function to check if an instance matches a discovery query
fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
    match (instance, query) {
50
        // Endpoint matching
51
52
        (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
        (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53
54
55
56
            &inst.namespace == namespace
        }
        (
            DiscoveryInstance::Endpoint(inst),
57
            DiscoveryQuery::ComponentEndpoints {
58
59
60
61
62
63
                namespace,
                component,
            },
        ) => &inst.namespace == namespace && &inst.component == component,
        (
            DiscoveryInstance::Endpoint(inst),
64
            DiscoveryQuery::Endpoint {
65
66
67
68
69
70
71
72
73
74
                namespace,
                component,
                endpoint,
            },
        ) => {
            &inst.namespace == namespace
                && &inst.component == component
                && &inst.endpoint == endpoint
        }

75
76
        // Model matching
        (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77
        (
78
            DiscoveryInstance::Model {
79
                namespace: inst_ns, ..
80
            },
81
            DiscoveryQuery::NamespacedModels { namespace },
82
        ) => inst_ns == namespace,
83
        (
84
            DiscoveryInstance::Model {
85
86
                namespace: inst_ns,
                component: inst_comp,
87
88
                ..
            },
89
            DiscoveryQuery::ComponentModels {
90
91
92
                namespace,
                component,
            },
93
        ) => inst_ns == namespace && inst_comp == component,
94
        (
95
            DiscoveryInstance::Model {
96
97
98
                namespace: inst_ns,
                component: inst_comp,
                endpoint: inst_ep,
99
100
                ..
            },
101
            DiscoveryQuery::EndpointModels {
102
103
104
105
                namespace,
                component,
                endpoint,
            },
106
107
        ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
        // EventChannel matching - unified query
        (
            DiscoveryInstance::EventChannel {
                namespace: inst_ns,
                component: inst_comp,
                topic: inst_topic,
                ..
            },
            DiscoveryQuery::EventChannels(query),
        ) => {
            query.namespace.as_ref().is_none_or(|ns| ns == inst_ns)
                && query.component.as_ref().is_none_or(|c| c == inst_comp)
                && query.topic.as_ref().is_none_or(|t| t == inst_topic)
        }

123
124
125
        // Cross-type matches return false
        (
            DiscoveryInstance::Endpoint(_),
126
127
128
            DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
129
130
            | DiscoveryQuery::EndpointModels { .. }
            | DiscoveryQuery::EventChannels(_),
131
132
        ) => false,
        (
133
134
135
136
            DiscoveryInstance::Model { .. },
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
137
138
139
140
141
142
143
144
145
146
147
148
149
            | DiscoveryQuery::Endpoint { .. }
            | DiscoveryQuery::EventChannels(_),
        ) => false,
        (
            DiscoveryInstance::EventChannel { .. },
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
            | DiscoveryQuery::Endpoint { .. }
            | DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
            | DiscoveryQuery::EndpointModels { .. },
150
        ) => false,
151
152
153
154
    }
}

#[async_trait]
155
impl Discovery for MockDiscovery {
156
157
158
159
    fn instance_id(&self) -> u64 {
        self.instance_id
    }

160
    async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
161
162
163
164
165
166
167
168
169
170
171
        let instance = spec.with_instance_id(self.instance_id);

        self.registry
            .instances
            .lock()
            .unwrap()
            .push(instance.clone());

        Ok(instance)
    }

172
    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
173
        let target_id = instance.id();
174
175
176
177
178

        self.registry
            .instances
            .lock()
            .unwrap()
179
            .retain(|i| i.id() != target_id);
180
181
182
183

        Ok(())
    }

184
    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
185
186
187
        let instances = self.registry.instances.lock().unwrap();
        Ok(instances
            .iter()
188
            .filter(|instance| matches_query(instance, &query))
189
190
191
192
            .cloned()
            .collect())
    }

193
194
195
196
197
    async fn list_and_watch(
        &self,
        query: DiscoveryQuery,
        _cancel_token: Option<CancellationToken>,
    ) -> Result<DiscoveryStream> {
198
199
200
201
202
        use std::collections::HashSet;

        let registry = self.registry.clone();

        let stream = async_stream::stream! {
203
            let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
204
205
206
207
208
209

            loop {
                let current: Vec<_> = {
                    let instances = registry.instances.lock().unwrap();
                    instances
                        .iter()
210
                        .filter(|instance| matches_query(instance, &query))
211
212
213
214
                        .cloned()
                        .collect()
                };

215
                let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
216
217
218

                // Emit Added events for new instances
                for instance in current {
219
                    let id = instance.id();
220
221
222
223
224
225
226
227
                    if known_instances.insert(id) {
                        yield Ok(DiscoveryEvent::Added(instance));
                    }
                }

                // Emit Removed events for instances that are gone
                for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
                    known_instances.remove(&id);
228
                    yield Ok(DiscoveryEvent::Removed(id));
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
                }

                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            }
        };

        Ok(Box::pin(stream))
    }
}

#[cfg(test)]
mod tests {
    //! Tests for the mock discovery backend.
    //!
    //! NOTE(review): the `register_*` tests go through `Discovery::register`,
    //! which is defined outside this file; the rejection cases presumably rely
    //! on validation performed there — confirm against the trait.

    use super::*;
    use futures::StreamExt;

    /// Creates two clients, with instance ids 1 and 2, on one fresh registry.
    fn client_pair() -> (MockDiscovery, MockDiscovery) {
        let registry = SharedMockRegistry::new();
        let first = MockDiscovery::new(Some(1), registry.clone());
        let second = MockDiscovery::new(Some(2), registry);
        (first, second)
    }

    /// Query selecting every model registered on `ns/comp/generate`.
    fn generate_models_query() -> DiscoveryQuery {
        DiscoveryQuery::EndpointModels {
            namespace: "ns".to_owned(),
            component: "comp".to_owned(),
            endpoint: "generate".to_owned(),
        }
    }

    /// Model spec without a suffix whose card carries only `display_name`.
    fn model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
    ) -> DiscoverySpec {
        let card_json = serde_json::json!({
            "display_name": model_name,
        });

        DiscoverySpec::Model {
            namespace: namespace.to_owned(),
            component: component.to_owned(),
            endpoint: endpoint.to_owned(),
            card_json,
            model_suffix: None,
        }
    }

    /// Model spec for a LoRA adapter: the card names its `source_path` and
    /// LoRA name, and the adapter name doubles as the model suffix.
    fn lora_model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
        source_path: &str,
        lora_name: &str,
    ) -> DiscoverySpec {
        let card_json = serde_json::json!({
            "display_name": model_name,
            "source_path": source_path,
            "lora": {
                "name": lora_name,
            },
        });

        DiscoverySpec::Model {
            namespace: namespace.to_owned(),
            component: component.to_owned(),
            endpoint: endpoint.to_owned(),
            card_json,
            model_suffix: Some(lora_name.to_owned()),
        }
    }

    /// Base model `base-model` on `ns/comp/generate`, sourced from `base-repo`.
    fn base_model_spec() -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: "ns".to_owned(),
            component: "comp".to_owned(),
            endpoint: "generate".to_owned(),
            card_json: serde_json::json!({
                "display_name": "base-model",
                "source_path": "base-repo",
            }),
            model_suffix: None,
        }
    }

    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let (client1, client2) = client_pair();

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_owned(),
            component: "test-comp".to_owned(),
            endpoint: "test-ep".to_owned(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_owned(),
            component: "test-comp".to_owned(),
            endpoint: "test-ep".to_owned(),
        };

        // Open the watch before anything is registered.
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // First registration must surface as an Added event carrying id 1.
        let instance1 = client1.register(spec.clone()).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(first_added)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-1");
        };
        assert_eq!(first_added.instance_id, 1);

        // Second registration, made through the other client, carries id 2.
        client2.register(spec).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(second_added)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-2");
        };
        assert_eq!(second_added.instance_id, 2);

        // Unregistering the first instance must surface as a Removed event for id 1.
        client1.unregister(instance1).await.unwrap();

        let DiscoveryEvent::Removed(removed_id) = stream.next().await.unwrap().unwrap() else {
            panic!("Expected Removed event for instance-1");
        };
        let endpoint_id = removed_id
            .extract_endpoint_id()
            .expect("Expected endpoint removal");
        assert_eq!(endpoint_id.instance_id, 1);
    }

    #[tokio::test]
    async fn register_allows_same_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();
        let spec = model_spec("ns", "comp", "generate", "model-a");

        discovery1.register(spec.clone()).await.unwrap();
        discovery2.register(spec).await.unwrap();

        let registered = discovery1.list(generate_models_query()).await.unwrap();
        assert_eq!(registered.len(), 2);
    }

    #[tokio::test]
    async fn register_rejects_different_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();

        let first = model_spec("ns", "comp", "generate", "model-a");
        discovery1.register(first).await.unwrap();

        let conflicting = model_spec("ns", "comp", "generate", "model-b");
        let err = discovery2.register(conflicting).await.unwrap_err();

        let message = err.to_string();
        assert!(message.contains(
            "Cannot register model 'model-b' on endpoint 'ns/comp/generate': a different model 'model-a' is already registered there"
        ));

        // The rejected registration must not have been stored.
        let registered = discovery1.list(generate_models_query()).await.unwrap();
        assert_eq!(registered.len(), 1);
    }

    #[tokio::test]
    async fn register_allows_different_model_names_on_different_endpoints() {
        let (discovery1, discovery2) = client_pair();

        let on_endpoint_a = model_spec("ns", "comp", "generate-a", "model-a");
        let on_endpoint_b = model_spec("ns", "comp", "generate-b", "model-b");

        discovery1.register(on_endpoint_a).await.unwrap();
        discovery2.register(on_endpoint_b).await.unwrap();
    }

    #[tokio::test]
    async fn register_allows_lora_adapter_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();

        discovery1.register(base_model_spec()).await.unwrap();

        // The adapter names the same source repository as the base model.
        let adapter = lora_model_spec(
            "ns",
            "comp",
            "generate",
            "adapter-a",
            "base-repo",
            "adapter-a",
        );
        discovery2.register(adapter).await.unwrap();
    }

    #[tokio::test]
    async fn register_rejects_lora_adapter_for_different_base_model() {
        let (discovery1, discovery2) = client_pair();

        discovery1.register(base_model_spec()).await.unwrap();

        // The adapter names a source repository other than the base model's.
        let adapter = lora_model_spec(
            "ns",
            "comp",
            "generate",
            "adapter-a",
            "other-base-repo",
            "adapter-a",
        );
        let err = discovery2.register(adapter).await.unwrap_err();

        let message = err.to_string();
        assert!(message.contains(
            "Cannot register model 'adapter-a' on endpoint 'ns/comp/generate': a different model 'base-model' is already registered there"
        ));
    }
}