// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{
    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
    DiscoverySpec, DiscoveryStream,
};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;

/// Shared in-memory registry for mock discovery.
///
/// Clones share the same underlying list (the field is an `Arc`). An instance registered
/// through one [`MockDiscovery`] client is therefore visible to every other client built
/// from a clone of the same registry.
#[derive(Clone, Default)]
pub struct SharedMockRegistry {
    /// Every currently registered instance, kept in registration order.
    instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
}

impl SharedMockRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            instances: Arc::new(Mutex::new(Vec::new())),
        }
    }
}

25
26
27
/// Mock implementation of Discovery for testing
/// We can potentially remove this once we have KVStoreDiscovery fully tested
pub struct MockDiscovery {
28
29
30
31
    instance_id: u64,
    registry: SharedMockRegistry,
}

32
impl MockDiscovery {
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
        let instance_id = instance_id.unwrap_or_else(|| {
            use std::sync::atomic::{AtomicU64, Ordering};
            static COUNTER: AtomicU64 = AtomicU64::new(1);
            COUNTER.fetch_add(1, Ordering::SeqCst)
        });

        Self {
            instance_id,
            registry,
        }
    }
}

47
48
49
/// Helper function to check if an instance matches a discovery query
fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
    match (instance, query) {
50
        // Endpoint matching
51
52
        (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
        (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53
54
55
56
            &inst.namespace == namespace
        }
        (
            DiscoveryInstance::Endpoint(inst),
57
            DiscoveryQuery::ComponentEndpoints {
58
59
60
61
62
63
                namespace,
                component,
            },
        ) => &inst.namespace == namespace && &inst.component == component,
        (
            DiscoveryInstance::Endpoint(inst),
64
            DiscoveryQuery::Endpoint {
65
66
67
68
69
70
71
72
73
74
                namespace,
                component,
                endpoint,
            },
        ) => {
            &inst.namespace == namespace
                && &inst.component == component
                && &inst.endpoint == endpoint
        }

75
76
        // Model matching
        (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77
        (
78
            DiscoveryInstance::Model {
79
                namespace: inst_ns, ..
80
            },
81
            DiscoveryQuery::NamespacedModels { namespace },
82
        ) => inst_ns == namespace,
83
        (
84
            DiscoveryInstance::Model {
85
86
                namespace: inst_ns,
                component: inst_comp,
87
88
                ..
            },
89
            DiscoveryQuery::ComponentModels {
90
91
92
                namespace,
                component,
            },
93
        ) => inst_ns == namespace && inst_comp == component,
94
        (
95
            DiscoveryInstance::Model {
96
97
98
                namespace: inst_ns,
                component: inst_comp,
                endpoint: inst_ep,
99
100
                ..
            },
101
            DiscoveryQuery::EndpointModels {
102
103
104
105
                namespace,
                component,
                endpoint,
            },
106
107
        ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
        // EventChannel matching - unified query
        (
            DiscoveryInstance::EventChannel {
                namespace: inst_ns,
                component: inst_comp,
                topic: inst_topic,
                ..
            },
            DiscoveryQuery::EventChannels(query),
        ) => {
            query.namespace.as_ref().is_none_or(|ns| ns == inst_ns)
                && query.component.as_ref().is_none_or(|c| c == inst_comp)
                && query.topic.as_ref().is_none_or(|t| t == inst_topic)
        }

123
124
125
        // Cross-type matches return false
        (
            DiscoveryInstance::Endpoint(_),
126
127
128
            DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
129
130
            | DiscoveryQuery::EndpointModels { .. }
            | DiscoveryQuery::EventChannels(_),
131
132
        ) => false,
        (
133
134
135
136
            DiscoveryInstance::Model { .. },
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
137
138
139
140
141
142
143
144
145
146
147
148
149
            | DiscoveryQuery::Endpoint { .. }
            | DiscoveryQuery::EventChannels(_),
        ) => false,
        (
            DiscoveryInstance::EventChannel { .. },
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
            | DiscoveryQuery::Endpoint { .. }
            | DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
            | DiscoveryQuery::EndpointModels { .. },
150
        ) => false,
151
152
153
154
    }
}

#[async_trait]
155
impl Discovery for MockDiscovery {
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    fn instance_id(&self) -> u64 {
        self.instance_id
    }

    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
        let instance = spec.with_instance_id(self.instance_id);

        self.registry
            .instances
            .lock()
            .unwrap()
            .push(instance.clone());

        Ok(instance)
    }

172
173
174
175
176
177
178
179
180
181
182
183
    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
        let instance_id = instance.instance_id();

        self.registry
            .instances
            .lock()
            .unwrap()
            .retain(|i| i.instance_id() != instance_id);

        Ok(())
    }

184
    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
185
186
187
        let instances = self.registry.instances.lock().unwrap();
        Ok(instances
            .iter()
188
            .filter(|instance| matches_query(instance, &query))
189
190
191
192
            .cloned()
            .collect())
    }

193
194
195
196
197
    async fn list_and_watch(
        &self,
        query: DiscoveryQuery,
        _cancel_token: Option<CancellationToken>,
    ) -> Result<DiscoveryStream> {
198
199
200
201
202
        use std::collections::HashSet;

        let registry = self.registry.clone();

        let stream = async_stream::stream! {
203
            let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
204
205
206
207
208
209

            loop {
                let current: Vec<_> = {
                    let instances = registry.instances.lock().unwrap();
                    instances
                        .iter()
210
                        .filter(|instance| matches_query(instance, &query))
211
212
213
214
                        .cloned()
                        .collect()
                };

215
                let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
216
217
218

                // Emit Added events for new instances
                for instance in current {
219
                    let id = instance.id();
220
221
222
223
224
225
226
227
                    if known_instances.insert(id) {
                        yield Ok(DiscoveryEvent::Added(instance));
                    }
                }

                // Emit Removed events for instances that are gone
                for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
                    known_instances.remove(&id);
228
                    yield Ok(DiscoveryEvent::Removed(id));
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
                }

                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            }
        };

        Ok(Box::pin(stream))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a single watch event. Without it, a missing
    /// event would hang the test forever instead of failing it.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Awaits the next event from `stream`. Fails the test on timeout, on end of stream, or
    /// if the stream yields an error.
    async fn next_event(stream: &mut DiscoveryStream) -> DiscoveryEvent {
        tokio::time::timeout(EVENT_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for a discovery event")
            .expect("discovery stream ended unexpectedly")
            .expect("discovery stream yielded an error")
    }

    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let registry = SharedMockRegistry::new();
        let client1 = MockDiscovery::new(Some(1), registry.clone());
        let client2 = MockDiscovery::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        // Start watching
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // Add first instance
        let instance1 = client1.register(spec.clone()).await.unwrap();
        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 1);
            }
            _ => panic!("Expected Added event for instance-1"),
        }

        // Add second instance
        client2.register(spec).await.unwrap();
        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 2);
            }
            _ => panic!("Expected Added event for instance-2"),
        }

        // Remove first instance through the public API rather than editing registry internals
        client1.unregister(instance1).await.unwrap();
        match next_event(&mut stream).await {
            DiscoveryEvent::Removed(id) => {
                let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
                assert_eq!(endpoint_id.instance_id, 1);
            }
            _ => panic!("Expected Removed event for instance-1"),
        }
    }
}