mod.rs 12.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use anyhow::Result;
5
6
7
8
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
9
use tokio_util::sync::CancellationToken;
10

11
12
13
mod metadata;
pub use metadata::{DiscoveryMetadata, MetadataSnapshot};

14
mod mock;
15
16
17
pub use mock::{MockDiscovery, SharedMockRegistry};
mod kv_store;
pub use kv_store::KVStoreDiscovery;
18
19
20
21

mod kube;
pub use kube::{KubeDiscoveryClient, hash_pod_name};

22
pub mod utils;
23
use crate::component::TransportType;
24
25
pub use utils::watch_and_extract_field;

26
27
28
/// Key selecting which discovery objects a prefix-based query should return.
///
/// The variants form a hierarchy that narrows from every object in the system down to a single
/// endpoint. It is the `query` argument of [`Discovery::list`] and
/// [`Discovery::list_and_watch`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscoveryQuery {
    /// Every endpoint in the system.
    AllEndpoints,
    /// Every endpoint that lives in one namespace.
    NamespacedEndpoints {
        namespace: String,
    },
    /// Every endpoint of one component within a namespace.
    ComponentEndpoints {
        namespace: String,
        component: String,
    },
    /// Exactly one endpoint, addressed by namespace, component and endpoint name.
    Endpoint {
        namespace: String,
        component: String,
        endpoint: String,
    },
    /// Model-side variant; presumably the model counterpart of [`Self::AllEndpoints`]
    /// (scope inferred from the name — confirm against the backends).
    AllModels,
    /// Model-side variant scoped by namespace (scope inferred from the fields).
    NamespacedModels {
        namespace: String,
    },
    /// Model-side variant scoped by namespace and component (scope inferred from the fields).
    ComponentModels {
        namespace: String,
        component: String,
    },
    /// Model-side variant scoped by namespace, component and endpoint
    /// (scope inferred from the fields).
    EndpointModels {
        namespace: String,
        component: String,
        endpoint: String,
    },
}

/// Specification for registering objects in the discovery plane
/// Represents the input to the register() operation
64
#[derive(Debug, Clone, PartialEq, Eq)]
65
66
67
68
69
70
pub enum DiscoverySpec {
    /// Endpoint specification for registration
    Endpoint {
        namespace: String,
        component: String,
        endpoint: String,
71
72
73
        /// Transport type and routing information
        transport: TransportType,
    },
74
    Model {
75
76
77
78
79
        namespace: String,
        component: String,
        endpoint: String,
        /// ModelDeploymentCard serialized as JSON
        /// This allows lib/runtime to remain independent of lib/llm types
80
        /// DiscoverySpec.from_model() and DiscoveryInstance.deserialize_model() are ergonomic helpers to create and deserialize the model card.
81
        card_json: serde_json::Value,
82
83
84
        /// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
        /// Key format: {namespace}/{component}/{endpoint}/{instance_id}[/{model_suffix}]
        model_suffix: Option<String>,
85
86
87
88
    },
}

impl DiscoverySpec {
89
    /// Creates a Model discovery spec from a serializable type
90
    /// The card will be serialized to JSON to avoid cross-crate dependencies
91
    pub fn from_model<T>(
92
93
94
95
        namespace: String,
        component: String,
        endpoint: String,
        card: &T,
96
    ) -> Result<Self>
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    where
        T: Serialize,
    {
        Self::from_model_with_suffix(namespace, component, endpoint, card, None)
    }

    /// Creates a Model discovery spec with an optional suffix (e.g., for LoRA adapters)
    /// The suffix is appended after the instance_id in the key path
    pub fn from_model_with_suffix<T>(
        namespace: String,
        component: String,
        endpoint: String,
        card: &T,
        model_suffix: Option<String>,
    ) -> Result<Self>
112
113
114
115
    where
        T: Serialize,
    {
        let card_json = serde_json::to_value(card)?;
116
        Ok(Self::Model {
117
118
119
120
            namespace,
            component,
            endpoint,
            card_json,
121
            model_suffix,
122
123
124
        })
    }

125
126
127
128
129
130
131
    /// Attaches an instance ID to create a DiscoveryInstance
    pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance {
        match self {
            Self::Endpoint {
                namespace,
                component,
                endpoint,
132
133
134
135
136
137
138
139
                transport,
            } => DiscoveryInstance::Endpoint(crate::component::Instance {
                namespace,
                component,
                endpoint,
                instance_id,
                transport,
            }),
140
            Self::Model {
141
142
143
144
                namespace,
                component,
                endpoint,
                card_json,
145
                model_suffix,
146
            } => DiscoveryInstance::Model {
147
148
149
150
                namespace,
                component,
                endpoint,
                instance_id,
151
                card_json,
152
                model_suffix,
153
154
155
156
157
158
159
            },
        }
    }
}

/// Registered instances in the discovery plane
/// Represents objects that have been successfully registered with an instance ID
160
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
161
162
#[serde(tag = "type")]
pub enum DiscoveryInstance {
163
164
    /// Registered endpoint instance - wraps the component::Instance directly
    Endpoint(crate::component::Instance),
165
    Model {
166
167
168
169
        namespace: String,
        component: String,
        endpoint: String,
        instance_id: u64,
170
171
172
        /// ModelDeploymentCard serialized as JSON
        /// This allows lib/runtime to remain independent of lib/llm types
        card_json: serde_json::Value,
173
174
175
        /// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
        #[serde(default, skip_serializing_if = "Option::is_none")]
        model_suffix: Option<String>,
176
    },
177
178
179
180
181
182
183
}

impl DiscoveryInstance {
    /// Returns the instance ID for this discovery instance
    pub fn instance_id(&self) -> u64 {
        match self {
            Self::Endpoint(inst) => inst.instance_id,
184
            Self::Model { instance_id, .. } => *instance_id,
185
186
187
        }
    }

188
189
    /// Deserializes the model JSON into the specified type T
    /// Returns an error if this is not a Model instance or if deserialization fails
190
    pub fn deserialize_model<T>(&self) -> Result<T>
191
192
193
194
    where
        T: for<'de> Deserialize<'de>,
    {
        match self {
195
            Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
196
            Self::Endpoint(_) => {
197
                anyhow::bail!("Cannot deserialize model from Endpoint instance")
198
199
200
            }
        }
    }
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344

    /// Extracts the unique identifier for this discovery instance
    /// Used for tracking, diffing, and removal events
    pub fn id(&self) -> DiscoveryInstanceId {
        match self {
            Self::Endpoint(inst) => DiscoveryInstanceId::Endpoint(EndpointInstanceId {
                namespace: inst.namespace.clone(),
                component: inst.component.clone(),
                endpoint: inst.endpoint.clone(),
                instance_id: inst.instance_id,
            }),
            Self::Model {
                namespace,
                component,
                endpoint,
                instance_id,
                model_suffix,
                ..
            } => DiscoveryInstanceId::Model(ModelCardInstanceId {
                namespace: namespace.clone(),
                component: component.clone(),
                endpoint: endpoint.clone(),
                instance_id: *instance_id,
                model_suffix: model_suffix.clone(),
            }),
        }
    }
}

/// Unique identifier for an endpoint instance.
///
/// The four fields map one-to-one onto the segments of the path produced by `to_path` and
/// consumed by `from_path`: `{namespace}/{component}/{endpoint}/{instance_id:x}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EndpointInstanceId {
    /// First path segment.
    pub namespace: String,
    /// Second path segment.
    pub component: String,
    /// Third path segment.
    pub endpoint: String,
    /// Fourth path segment; written as lowercase hexadecimal in the path form.
    pub instance_id: u64,
}

impl EndpointInstanceId {
    /// Renders the identifier as `{namespace}/{component}/{endpoint}/{instance_id:x}`,
    /// with the instance ID in lowercase hexadecimal.
    pub fn to_path(&self) -> String {
        format!(
            "{}/{}/{}/{:x}",
            self.namespace, self.component, self.endpoint, self.instance_id
        )
    }

    /// Reconstructs an identifier from `{namespace}/{component}/{endpoint}/{instance_id:x}`.
    ///
    /// # Errors
    /// Returns an error if the path does not split into exactly four `/`-separated segments,
    /// or if the last segment is not a valid hexadecimal `u64`.
    pub fn from_path(path: &str) -> Result<Self> {
        let segments: Vec<&str> = path.split('/').collect();
        match segments[..] {
            [namespace, component, endpoint, id_hex] => {
                let instance_id = u64::from_str_radix(id_hex, 16)
                    .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?;
                Ok(Self {
                    namespace: namespace.to_string(),
                    component: component.to_string(),
                    endpoint: endpoint.to_string(),
                    instance_id,
                })
            }
            _ => anyhow::bail!(
                "Invalid EndpointInstanceId path: expected 4 parts, got {}",
                segments.len()
            ),
        }
    }
}

/// Unique identifier for a model card instance.
///
/// The combination of (namespace, component, endpoint, instance_id, model_suffix) uniquely
/// identifies a model card. Its path form is
/// `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelCardInstanceId {
    /// First path segment.
    pub namespace: String,
    /// Second path segment.
    pub component: String,
    /// Third path segment.
    pub endpoint: String,
    /// Fourth path segment; written as lowercase hexadecimal in the path form.
    pub instance_id: u64,
    /// None for base models, Some(slug) for LoRA adapters.
    /// When present it becomes the optional fifth path segment.
    pub model_suffix: Option<String>,
}

impl ModelCardInstanceId {
    /// Renders the identifier as
    /// `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`.
    ///
    /// The trailing `/{model_suffix}` segment is only emitted when a suffix is set.
    pub fn to_path(&self) -> String {
        let mut path = format!(
            "{}/{}/{}/{:x}",
            self.namespace, self.component, self.endpoint, self.instance_id
        );
        if let Some(suffix) = self.model_suffix.as_deref() {
            path.push('/');
            path.push_str(suffix);
        }
        path
    }

    /// Reconstructs an identifier from
    /// `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`.
    ///
    /// # Errors
    /// Returns an error if the path does not split into four or five `/`-separated segments,
    /// or if the fourth segment is not a valid hexadecimal `u64`.
    pub fn from_path(path: &str) -> Result<Self> {
        let segments: Vec<&str> = path.split('/').collect();
        // Four segments means no suffix; a fifth segment, when present, is the suffix.
        let (namespace, component, endpoint, id_hex, suffix) = match segments[..] {
            [namespace, component, endpoint, id_hex] => {
                (namespace, component, endpoint, id_hex, None)
            }
            [namespace, component, endpoint, id_hex, suffix] => {
                (namespace, component, endpoint, id_hex, Some(suffix))
            }
            _ => anyhow::bail!(
                "Invalid ModelCardInstanceId path: expected 4 or 5 parts, got {}",
                segments.len()
            ),
        };
        let instance_id = u64::from_str_radix(id_hex, 16)
            .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?;
        Ok(Self {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            instance_id,
            model_suffix: suffix.map(str::to_string),
        })
    }
}

/// Union of instance identifiers for different discovery object types.
///
/// Produced by `DiscoveryInstance::id` and carried by `DiscoveryEvent::Removed`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DiscoveryInstanceId {
    /// Identifier of an endpoint instance.
    Endpoint(EndpointInstanceId),
    /// Identifier of a model card instance.
    Model(ModelCardInstanceId),
}

impl DiscoveryInstanceId {
    /// Returns the raw instance_id regardless of variant type
    pub fn instance_id(&self) -> u64 {
        match self {
            Self::Endpoint(eid) => eid.instance_id,
            Self::Model(mid) => mid.instance_id,
        }
    }

    /// Extracts the EndpointInstanceId, returning an error if this is a Model variant
    pub fn extract_endpoint_id(&self) -> Result<&EndpointInstanceId> {
        match self {
            Self::Endpoint(eid) => Ok(eid),
            Self::Model(_) => anyhow::bail!("Expected Endpoint variant, got Model"),
        }
    }

    /// Extracts the ModelCardInstanceId, returning an error if this is an Endpoint variant
    pub fn extract_model_id(&self) -> Result<&ModelCardInstanceId> {
        match self {
            Self::Model(mid) => Ok(mid),
            Self::Endpoint(_) => anyhow::bail!("Expected Model variant, got Endpoint"),
        }
    }
345
346
}

347
/// Events emitted by the discovery watch stream
348
349
350
351
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiscoveryEvent {
    /// A new instance was added
    Added(DiscoveryInstance),
352
353
    /// An instance was removed (identified by its unique ID)
    Removed(DiscoveryInstanceId),
354
355
356
357
358
}

/// Stream type for discovery events.
///
/// A pinned, boxed, `Send` stream whose items are `Result<DiscoveryEvent>`, so each item is
/// either an event or an error.
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;

359
/// Discovery trait for service discovery across different backends
360
#[async_trait]
361
pub trait Discovery: Send + Sync {
362
363
364
365
366
367
368
    /// Returns a unique identifier for this worker (e.g lease id if using etcd or generated id for memory store)
    /// Discovery objects created by this worker will be associated with this id.
    fn instance_id(&self) -> u64;

    /// Registers an object in the discovery plane with the instance id
    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;

369
370
371
    /// Unregisters an instance from the discovery plane
    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()>;

372
    /// Returns a list of currently registered instances for the given discovery query
373
    /// This is a one-time snapshot without watching for changes
374
375
376
377
378
379
380
381
382
    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>>;

    /// Returns a stream of discovery events (Added/Removed) for the given discovery query
    /// The optional cancellation token can be used to stop the watch stream
    async fn list_and_watch(
        &self,
        query: DiscoveryQuery,
        cancel_token: Option<CancellationToken>,
    ) -> Result<DiscoveryStream>;
383
}