Unverified Commit e1547e24 authored by mohammedabdulwahhab's avatar mohammedabdulwahhab Committed by GitHub
Browse files

fix: expand discovery interface to support model types (#4090)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
parent 3c0763f7
...@@ -46,37 +46,79 @@ impl MockDiscoveryClient { ...@@ -46,37 +46,79 @@ impl MockDiscoveryClient {
/// Helper function to check if an instance matches a discovery key query /// Helper function to check if an instance matches a discovery key query
fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool { fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
match (instance, key) { match (instance, key) {
(DiscoveryInstance::Endpoint { .. }, DiscoveryKey::AllEndpoints) => true, // Endpoint matching
(DiscoveryInstance::Endpoint(_), DiscoveryKey::AllEndpoints) => true,
(DiscoveryInstance::Endpoint(inst), DiscoveryKey::NamespacedEndpoints { namespace }) => {
&inst.namespace == namespace
}
(
DiscoveryInstance::Endpoint(inst),
DiscoveryKey::ComponentEndpoints {
namespace,
component,
},
) => &inst.namespace == namespace && &inst.component == component,
(
DiscoveryInstance::Endpoint(inst),
DiscoveryKey::Endpoint {
namespace,
component,
endpoint,
},
) => {
&inst.namespace == namespace
&& &inst.component == component
&& &inst.endpoint == endpoint
}
// ModelCard matching
(DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllModelCards) => true,
( (
DiscoveryInstance::Endpoint { DiscoveryInstance::ModelCard {
namespace: ins_ns, .. namespace: inst_ns, ..
}, },
DiscoveryKey::NamespacedEndpoints { namespace }, DiscoveryKey::NamespacedModelCards { namespace },
) => ins_ns == namespace, ) => inst_ns == namespace,
( (
DiscoveryInstance::Endpoint { DiscoveryInstance::ModelCard {
namespace: ins_ns, namespace: inst_ns,
component: ins_comp, component: inst_comp,
.. ..
}, },
DiscoveryKey::ComponentEndpoints { DiscoveryKey::ComponentModelCards {
namespace, namespace,
component, component,
}, },
) => ins_ns == namespace && ins_comp == component, ) => inst_ns == namespace && inst_comp == component,
( (
DiscoveryInstance::Endpoint { DiscoveryInstance::ModelCard {
namespace: ins_ns, namespace: inst_ns,
component: ins_comp, component: inst_comp,
endpoint: ins_ep, endpoint: inst_ep,
.. ..
}, },
DiscoveryKey::Endpoint { DiscoveryKey::EndpointModelCards {
namespace, namespace,
component, component,
endpoint, endpoint,
}, },
) => ins_ns == namespace && ins_comp == component && ins_ep == endpoint, ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
// Cross-type matches return false
(
DiscoveryInstance::Endpoint(_),
DiscoveryKey::AllModelCards
| DiscoveryKey::NamespacedModelCards { .. }
| DiscoveryKey::ComponentModelCards { .. }
| DiscoveryKey::EndpointModelCards { .. },
) => false,
(
DiscoveryInstance::ModelCard { .. },
DiscoveryKey::AllEndpoints
| DiscoveryKey::NamespacedEndpoints { .. }
| DiscoveryKey::ComponentEndpoints { .. }
| DiscoveryKey::Endpoint { .. },
) => false,
} }
} }
...@@ -98,6 +140,15 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -98,6 +140,15 @@ impl DiscoveryClient for MockDiscoveryClient {
Ok(instance) Ok(instance)
} }
/// Returns a snapshot of every registered instance that matches `key`.
///
/// The registry mutex is held only while the snapshot is assembled; matching
/// instances are cloned so the returned vector is independent of the registry.
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>> {
    let registered = self.registry.instances.lock().unwrap();
    let mut matching = Vec::new();
    for candidate in registered.iter() {
        if matches_key(candidate, &key) {
            matching.push(candidate.clone());
        }
    }
    Ok(matching)
}
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> { async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> {
use std::collections::HashSet; use std::collections::HashSet;
...@@ -118,14 +169,16 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -118,14 +169,16 @@ impl DiscoveryClient for MockDiscoveryClient {
let current_ids: HashSet<_> = current.iter().map(|i| { let current_ids: HashSet<_> = current.iter().map(|i| {
match i { match i {
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id, DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
} }
}).collect(); }).collect();
// Emit Added events for new instances // Emit Added events for new instances
for instance in current { for instance in current {
let id = match &instance { let id = match &instance {
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id, DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
}; };
if known_instances.insert(id) { if known_instances.insert(id) {
yield Ok(DiscoveryEvent::Added(instance)); yield Ok(DiscoveryEvent::Added(instance));
...@@ -161,6 +214,7 @@ mod tests { ...@@ -161,6 +214,7 @@ mod tests {
namespace: "test-ns".to_string(), namespace: "test-ns".to_string(),
component: "test-comp".to_string(), component: "test-comp".to_string(),
endpoint: "test-ep".to_string(), endpoint: "test-ep".to_string(),
transport: crate::component::TransportType::NatsTcp("test-subject".to_string()),
}; };
let key = DiscoveryKey::Endpoint { let key = DiscoveryKey::Endpoint {
...@@ -177,8 +231,8 @@ mod tests { ...@@ -177,8 +231,8 @@ mod tests {
let event = stream.next().await.unwrap().unwrap(); let event = stream.next().await.unwrap().unwrap();
match event { match event {
DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. }) => { DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
assert_eq!(instance_id, 1); assert_eq!(inst.instance_id, 1);
} }
_ => panic!("Expected Added event for instance-1"), _ => panic!("Expected Added event for instance-1"),
} }
...@@ -188,15 +242,16 @@ mod tests { ...@@ -188,15 +242,16 @@ mod tests {
let event = stream.next().await.unwrap().unwrap(); let event = stream.next().await.unwrap().unwrap();
match event { match event {
DiscoveryEvent::Added(DiscoveryInstance::Endpoint { instance_id, .. }) => { DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
assert_eq!(instance_id, 2); assert_eq!(inst.instance_id, 2);
} }
_ => panic!("Expected Added event for instance-2"), _ => panic!("Expected Added event for instance-2"),
} }
// Remove first instance // Remove first instance
registry.instances.lock().unwrap().retain(|i| match i { registry.instances.lock().unwrap().retain(|i| match i {
DiscoveryInstance::Endpoint { instance_id, .. } => *instance_id != 1, DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id != 1,
}); });
let event = stream.next().await.unwrap().unwrap(); let event = stream.next().await.unwrap().unwrap();
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::Result; use crate::Result;
use crate::component::TransportType;
use async_trait::async_trait; use async_trait::async_trait;
use futures::Stream; use futures::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -10,6 +11,9 @@ use std::pin::Pin; ...@@ -10,6 +11,9 @@ use std::pin::Pin;
mod mock; mod mock;
pub use mock::{MockDiscoveryClient, SharedMockRegistry}; pub use mock::{MockDiscoveryClient, SharedMockRegistry};
pub mod utils;
pub use utils::watch_and_extract_field;
/// Query key for prefix-based discovery queries /// Query key for prefix-based discovery queries
/// Supports hierarchical queries from all endpoints down to specific endpoints /// Supports hierarchical queries from all endpoints down to specific endpoints
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
...@@ -17,7 +21,9 @@ pub enum DiscoveryKey { ...@@ -17,7 +21,9 @@ pub enum DiscoveryKey {
/// Query all endpoints in the system /// Query all endpoints in the system
AllEndpoints, AllEndpoints,
/// Query all endpoints in a specific namespace /// Query all endpoints in a specific namespace
NamespacedEndpoints { namespace: String }, NamespacedEndpoints {
namespace: String,
},
/// Query all endpoints in a namespace/component /// Query all endpoints in a namespace/component
ComponentEndpoints { ComponentEndpoints {
namespace: String, namespace: String,
...@@ -29,28 +35,65 @@ pub enum DiscoveryKey { ...@@ -29,28 +35,65 @@ pub enum DiscoveryKey {
component: String, component: String,
endpoint: String, endpoint: String,
}, },
// TODO: Extend to support ModelCard queries: AllModelCards,
// - AllModels NamespacedModelCards {
// - NamespacedModels { namespace } namespace: String,
// - ComponentModels { namespace, component } },
// - Model { namespace, component, model_name } ComponentModelCards {
namespace: String,
component: String,
},
EndpointModelCards {
namespace: String,
component: String,
endpoint: String,
},
} }
/// Specification for registering objects in the discovery plane /// Specification for registering objects in the discovery plane
/// Represents the input to the register() operation /// Represents the input to the register() operation
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiscoverySpec { pub enum DiscoverySpec {
/// Endpoint specification for registration /// Endpoint specification for registration
Endpoint { Endpoint {
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
/// Transport type and routing information
transport: TransportType,
},
ModelCard {
namespace: String,
component: String,
endpoint: String,
/// ModelDeploymentCard serialized as JSON
/// This allows lib/runtime to remain independent of lib/llm types
/// DiscoverySpec.from_model_card() and DiscoveryInstance.deserialize_model_card() are ergonomic helpers to create and deserialize the model card.
card_json: serde_json::Value,
}, },
// TODO: Add ModelCard variant:
// - ModelCard { namespace, component, model_name, card: ModelDeploymentCard }
} }
impl DiscoverySpec { impl DiscoverySpec {
/// Builds a `ModelCard` discovery spec from any serializable card type.
///
/// The card is converted to a `serde_json::Value` and stored in `card_json`,
/// so the spec carries the card as plain JSON instead of a concrete card type.
///
/// # Errors
/// Returns an error if `card` cannot be serialized to a JSON value.
pub fn from_model_card<T>(
    namespace: String,
    component: String,
    endpoint: String,
    card: &T,
) -> crate::Result<Self>
where
    T: Serialize,
{
    let spec = Self::ModelCard {
        namespace,
        component,
        endpoint,
        card_json: serde_json::to_value(card)?,
    };
    Ok(spec)
}
/// Attaches an instance ID to create a DiscoveryInstance /// Attaches an instance ID to create a DiscoveryInstance
pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance { pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance {
match self { match self {
...@@ -58,11 +101,25 @@ impl DiscoverySpec { ...@@ -58,11 +101,25 @@ impl DiscoverySpec {
namespace, namespace,
component, component,
endpoint, endpoint,
} => DiscoveryInstance::Endpoint { transport,
} => DiscoveryInstance::Endpoint(crate::component::Instance {
namespace,
component,
endpoint,
instance_id,
transport,
}),
Self::ModelCard {
namespace,
component,
endpoint,
card_json,
} => DiscoveryInstance::ModelCard {
namespace, namespace,
component, component,
endpoint, endpoint,
instance_id, instance_id,
card_json,
}, },
} }
} }
...@@ -70,18 +127,44 @@ impl DiscoverySpec { ...@@ -70,18 +127,44 @@ impl DiscoverySpec {
/// Registered instances in the discovery plane /// Registered instances in the discovery plane
/// Represents objects that have been successfully registered with an instance ID /// Represents objects that have been successfully registered with an instance ID
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum DiscoveryInstance { pub enum DiscoveryInstance {
/// Registered endpoint instance /// Registered endpoint instance - wraps the component::Instance directly
Endpoint { Endpoint(crate::component::Instance),
ModelCard {
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
instance_id: u64, instance_id: u64,
/// ModelDeploymentCard serialized as JSON
/// This allows lib/runtime to remain independent of lib/llm types
card_json: serde_json::Value,
}, },
// TODO: Add ModelCard variant: }
// - ModelCard { namespace, component, model_name, instance_id, card: ModelDeploymentCard }
impl DiscoveryInstance {
/// Returns the instance ID carried by this discovery instance, whichever
/// variant it is.
pub fn instance_id(&self) -> u64 {
    match self {
        Self::ModelCard {
            instance_id: id, ..
        } => *id,
        Self::Endpoint(endpoint) => endpoint.instance_id,
    }
}
/// Deserializes the stored model card JSON into `T`.
///
/// The stored JSON value is cloned before conversion, so `self` is left
/// untouched.
///
/// # Errors
/// Returns an error if this instance is an `Endpoint` rather than a
/// `ModelCard`, or if the JSON cannot be deserialized into `T`.
pub fn deserialize_model_card<T>(&self) -> crate::Result<T>
where
    T: for<'de> Deserialize<'de>,
{
    match self {
        Self::Endpoint(_) => {
            crate::raise!("Cannot deserialize model card from Endpoint instance")
        }
        Self::ModelCard { card_json, .. } => {
            let card = serde_json::from_value(card_json.clone())?;
            Ok(card)
        }
    }
}
} }
/// Events emitted by the discovery client watch stream /// Events emitted by the discovery client watch stream
...@@ -106,6 +189,10 @@ pub trait DiscoveryClient: Send + Sync { ...@@ -106,6 +189,10 @@ pub trait DiscoveryClient: Send + Sync {
/// Registers an object in the discovery plane with the instance id /// Registers an object in the discovery plane with the instance id
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>; async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
/// Returns a list of currently registered instances for the given discovery key
/// This is a one-time snapshot without watching for changes
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>>;
/// Returns a stream of discovery events (Added/Removed) for the given discovery key /// Returns a stream of discovery events (Added/Removed) for the given discovery key
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>; async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>;
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Utility functions for working with discovery streams
use serde::Deserialize;
use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream};
/// Watches a discovery stream and keeps a per-instance map of one extracted field.
///
/// A background Tokio task is spawned that consumes `stream` and maintains a
/// `HashMap<instance_id, V>`:
/// - `Added`: the instance's model card is deserialized into `T`, `extractor`
///   is applied, and the result is stored under the instance ID. If
///   deserialization fails, a warning is logged and the event is skipped
///   without publishing anything.
/// - `Removed`: the entry for that instance ID is dropped from the map.
/// - Stream errors are logged and otherwise ignored.
///
/// After every applied `Added`/`Removed` event a clone of the whole map is
/// published on the watch channel. The task ends when the stream is exhausted
/// or when publishing fails (which is treated as the receiver having been
/// dropped).
///
/// # Type Parameters
/// - `T`: The type deserialized from each model card (e.g., ModelDeploymentCard)
/// - `V`: The type of the extracted field (e.g., ModelRuntimeConfig)
/// - `F`: The extractor function type
///
/// # Arguments
/// - `stream`: The discovery event stream to consume
/// - `extractor`: Maps a deserialized `T` to the value stored in the map
///
/// # Returns
/// A `watch::Receiver` through which the latest published map can be read.
///
/// # Panics
/// `tokio::spawn` panics when called outside a Tokio runtime, so this
/// function must be called from within one.
///
/// # Example
/// ```ignore
/// let stream = discovery.list_and_watch(DiscoveryKey::ComponentModelCards { ... }).await?;
/// let runtime_configs_rx = watch_and_extract_field(
///     stream,
///     |card: ModelDeploymentCard| card.runtime_config,
/// );
///
/// // Use it:
/// let configs = runtime_configs_rx.borrow();
/// if let Some(config) = configs.get(&worker_id) {
///     // Use config...
/// }
/// ```
pub fn watch_and_extract_field<T, V, F>(
    stream: DiscoveryStream,
    extractor: F,
) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
where
    T: for<'de> Deserialize<'de> + 'static,
    V: Clone + Send + Sync + 'static,
    F: Fn(T) -> V + Send + 'static,
{
    use futures::StreamExt;
    use std::collections::HashMap;

    let (sender, receiver) = tokio::sync::watch::channel(HashMap::new());

    tokio::spawn(async move {
        let mut events = stream;
        let mut extracted: HashMap<u64, V> = HashMap::new();

        while let Some(item) = events.next().await {
            // Stream-level errors are reported but never terminate the task.
            let event = match item {
                Ok(event) => event,
                Err(e) => {
                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
                    continue;
                }
            };

            // Apply the event to the local map.
            match event {
                DiscoveryEvent::Added(instance) => {
                    let instance_id = instance.instance_id();
                    let card: T = match instance.deserialize_model_card() {
                        Ok(card) => card,
                        Err(e) => {
                            tracing::warn!(
                                instance_id,
                                error = %e,
                                "Failed to deserialize discovery instance, skipping"
                            );
                            // Nothing changed, so nothing is published.
                            continue;
                        }
                    };
                    extracted.insert(instance_id, extractor(card));
                }
                DiscoveryEvent::Removed(removed_id) => {
                    extracted.remove(&removed_id);
                }
            }

            // Publish a fresh snapshot; a failed send ends the task.
            if sender.send(extracted.clone()).is_err() {
                tracing::debug!("watch_and_extract_field receiver dropped, stopping");
                break;
            }
        }

        tracing::debug!("watch_and_extract_field task stopped");
    });

    receiver
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment