Unverified Commit 2f9812aa authored by mohammedabdulwahhab's avatar mohammedabdulwahhab Committed by GitHub
Browse files

fix: fix bug in diffing logic in list_and_watch (#5318)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
parent 91eb0ed8
......@@ -1306,12 +1306,12 @@ fn spawn_prefill_watcher(
}
}
}
DiscoveryEvent::Removed(instance_id) => {
DiscoveryEvent::Removed(id) => {
// Log removal for observability
// Note: The PrefillRouter remains active - worker availability
// is handled dynamically by the underlying Client's instance tracking
tracing::debug!(
instance_id = instance_id,
instance_id = id.instance_id(),
"Prefill worker instance removed from discovery"
);
}
......
......@@ -320,14 +320,14 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card);
Ok(())
}
/// Remove and return model card for this instance's etcd key. We do this when the instance stops.
/// Remove and return model card for this instance's key. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.lock().remove(key)
}
......
......@@ -10,7 +10,10 @@ use futures::StreamExt;
use dynamo_runtime::{
DistributedRuntime,
discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoveryStream},
discovery::{
DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery, DiscoveryStream,
ModelCardInstanceId,
},
pipeline::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter,
......@@ -119,23 +122,26 @@ impl ModelWatcher {
match event {
DiscoveryEvent::Added(instance) => {
// Extract EndpointId, instance_id, and card from the discovery instance
let (endpoint_id, instance_id, mut card) = match &instance {
// Extract ModelCardInstanceId and card from the discovery instance
let (mcid, mut card) = match &instance {
DiscoveryInstance::Model {
namespace,
component,
endpoint,
instance_id,
model_suffix,
..
} => {
let eid = EndpointId {
let mcid = ModelCardInstanceId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
endpoint: endpoint.clone(),
instance_id: *instance_id,
model_suffix: model_suffix.clone(),
};
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => (eid, *instance_id, card),
Ok(card) => (mcid, card),
Err(err) => {
tracing::error!(%err, instance_id, "Failed to deserialize model card");
continue;
......@@ -153,10 +159,10 @@ impl ModelWatcher {
// Filter by namespace if target_namespace is specified
if !global_namespace
&& let Some(target_ns) = target_namespace
&& endpoint_id.namespace != target_ns
&& mcid.namespace != target_ns
{
tracing::debug!(
model_namespace = endpoint_id.namespace,
model_namespace = mcid.namespace,
target_namespace = target_ns,
"Skipping model from different namespace"
);
......@@ -185,14 +191,11 @@ impl ModelWatcher {
continue;
}
// Use instance_id as the HashMap key (simpler and sufficient since keys are opaque)
let key = format!("{:x}", instance_id);
match self.handle_put(&key, &endpoint_id, &mut card).await {
match self.handle_put(&mcid, &mut card).await {
Ok(()) => {
tracing::info!(
model_name = card.name(),
namespace = endpoint_id.namespace,
namespace = mcid.namespace,
"added model"
);
self.notify_on_model.notify_waiters();
......@@ -200,19 +203,27 @@ impl ModelWatcher {
Err(err) => {
tracing::error!(
model_name = card.name(),
namespace = endpoint_id.namespace,
namespace = mcid.namespace,
error = format!("{err:#}"),
"Error adding model from discovery",
);
}
}
}
DiscoveryEvent::Removed(instance_id) => {
// Use instance_id hex as the HashMap key (matches what we saved with)
let key = format!("{:x}", instance_id);
DiscoveryEvent::Removed(id) => {
// Extract ModelCardInstanceId from the removal event
let model_card_instance_id = match &id {
DiscoveryInstanceId::Model(mcid) => mcid,
DiscoveryInstanceId::Endpoint(_) => {
tracing::error!(
"Unexpected discovery instance type in removal (expected Model)"
);
continue;
}
};
match self
.handle_delete(&key, target_namespace, global_namespace)
.handle_delete(model_card_instance_id, target_namespace, global_namespace)
.await
{
Ok(Some(model_name)) => {
......@@ -234,14 +245,15 @@ impl ModelWatcher {
/// Returns the name of the model we just deleted, if any.
async fn handle_delete(
&self,
key: &str,
mcid: &ModelCardInstanceId,
target_namespace: Option<&str>,
is_global_namespace: bool,
) -> anyhow::Result<Option<String>> {
let card = match self.manager.remove_model_card(key) {
let key = mcid.to_path();
let card = match self.manager.remove_model_card(&key) {
Some(card) => card,
None => {
anyhow::bail!("Missing ModelDeploymentCard for {key}");
anyhow::bail!("Missing ModelDeploymentCard for {}", key);
}
};
let model_name = card.name().to_string();
......@@ -325,20 +337,20 @@ impl ModelWatcher {
// models.
async fn handle_put(
&self,
key: &str,
endpoint_id: &EndpointId,
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.download_config().await?;
let component = self
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let endpoint = component.endpoint(&endpoint_id.name);
.namespace(&mcid.namespace)?
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?;
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
// Skip duplicate registrations based on model type.
// Prefill and decode models are tracked separately, so registering one
......@@ -352,7 +364,7 @@ impl ModelWatcher {
if already_registered {
tracing::debug!(
model_name = card.name(),
namespace = endpoint_id.namespace,
namespace = mcid.namespace,
model_type = %card.model_type,
"Model already registered, skipping"
);
......@@ -372,7 +384,7 @@ impl ModelWatcher {
// A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports.
let endpoint = component.endpoint(&endpoint_id.name);
let endpoint = component.endpoint(&mcid.endpoint);
let kv_chooser = if self.router_config.router_mode == RouterMode::KV {
Some(
self.manager
......
......@@ -564,10 +564,12 @@ pub async fn start_kv_router_background(
continue;
};
let DiscoveryEvent::Removed(worker_id) = discovery_event else {
let DiscoveryEvent::Removed(id) = discovery_event else {
continue;
};
let worker_id = id.instance_id();
tracing::warn!(
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
);
......@@ -642,11 +644,13 @@ pub async fn start_kv_router_background(
continue;
};
let DiscoveryEvent::Removed(router_instance_id) = router_event else {
let DiscoveryEvent::Removed(id) = router_event else {
// We only care about removals for cleaning up consumers
continue;
};
let router_instance_id = id.instance_id();
// The consumer UUID is the instance_id in hex format
let consumer_to_delete = router_instance_id.to_string();
......@@ -708,7 +712,8 @@ async fn handle_worker_discovery(
}
}
}
DiscoveryEvent::Removed(worker_id) => {
DiscoveryEvent::Removed(id) => {
let worker_id = id.instance_id();
tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
if let Err(e) = remove_worker_tx.send(worker_id).await {
......
......@@ -9,7 +9,7 @@ use arc_swap::ArcSwap;
use futures::StreamExt;
use tokio::net::unix::pipe::Receiver;
use crate::discovery::{DiscoveryEvent, DiscoveryInstance};
use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
use crate::{
component::{Endpoint, Instance},
pipeline::async_trait,
......@@ -255,8 +255,8 @@ impl Client {
map.insert(instance.instance_id, instance);
}
}
DiscoveryEvent::Removed(instance_id) => {
map.remove(&instance_id);
DiscoveryEvent::Removed(id) => {
map.remove(&id.instance_id());
}
}
......
......@@ -14,8 +14,8 @@ use utils::PodInfo;
use crate::CancellationToken;
use crate::discovery::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryMetadata, DiscoveryQuery, DiscoverySpec,
DiscoveryStream, MetadataSnapshot,
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryMetadata,
DiscoveryQuery, DiscoverySpec, DiscoveryStream, MetadataSnapshot,
};
use anyhow::Result;
use async_trait::async_trait;
......@@ -257,56 +257,53 @@ impl Discovery for KubeDiscoveryClient {
// Spawn task to process snapshots
tokio::spawn(async move {
// Initialize known_instances from current snapshot state
// Initialize from current snapshot state
// This is critical: watch_rx.changed() only fires on FUTURE changes,
// so we must capture the current state first to detect removals correctly
let initial_snapshot = watch_rx.borrow_and_update().clone();
let mut known_instances: HashSet<u64> = initial_snapshot
.instances
.iter()
.filter_map(|(&instance_id, metadata)| {
let filtered = metadata.filter(&query);
if !filtered.is_empty() {
Some(instance_id)
} else {
None
}
})
.collect();
// Build initial map: DiscoveryInstanceId -> DiscoveryInstance
let initial: std::collections::HashMap<DiscoveryInstanceId, DiscoveryInstance> =
initial_snapshot
.instances
.values()
.flat_map(|metadata| metadata.filter(&query))
.map(|instance| (instance.id(), instance))
.collect();
tracing::debug!(
stream_id = %stream_id,
initial_instances = known_instances.len(),
initial_count = initial.len(),
"Watch started for query={:?}",
query
);
// Emit initial Added events for all existing instances (the "list" part of list_and_watch)
for &instance_id in &known_instances {
if let Some(metadata) = initial_snapshot.instances.get(&instance_id) {
let instances = metadata.filter(&query);
for instance in instances {
tracing::info!(
stream_id = %stream_id,
instance_id = format!("{:x}", instance.instance_id()),
"Emitting initial Added event"
);
if event_tx.send(Ok(DiscoveryEvent::Added(instance))).is_err() {
tracing::debug!(
stream_id = %stream_id,
"Watch receiver dropped during initial sync"
);
return;
}
}
// Emit initial Added events (the "list" part of list_and_watch)
for instance in initial.values() {
tracing::info!(
stream_id = %stream_id,
instance_id = format!("{:x}", instance.instance_id()),
"Emitting initial Added event"
);
if event_tx
.send(Ok(DiscoveryEvent::Added(instance.clone())))
.is_err()
{
tracing::debug!(
stream_id = %stream_id,
"Watch receiver dropped during initial sync"
);
return;
}
}
// Track known instances by their unique ID
let mut known: HashSet<DiscoveryInstanceId> = initial.into_keys().collect();
loop {
tracing::trace!(
stream_id = %stream_id,
known_count = known_instances.len(),
known_count = known.len(),
"Watch loop waiting for changes"
);
......@@ -331,44 +328,35 @@ impl Discovery for KubeDiscoveryClient {
// Get latest snapshot
let snapshot = watch_rx.borrow_and_update().clone();
// Build current map: DiscoveryInstanceId -> DiscoveryInstance
let current: std::collections::HashMap<
DiscoveryInstanceId,
DiscoveryInstance,
> = snapshot
.instances
.values()
.flat_map(|metadata| metadata.filter(&query))
.map(|instance| (instance.id(), instance))
.collect();
tracing::debug!(
stream_id = %stream_id,
seq = snapshot.sequence,
snapshot_instances = snapshot.instances.len(),
known_instances = known_instances.len(),
current_count = current.len(),
known_count = known.len(),
"Watch received snapshot update"
);
// Filter snapshot by query
let current_instances: HashSet<u64> = snapshot
.instances
.iter()
.filter_map(|(&instance_id, metadata)| {
let filtered = metadata.filter(&query);
if !filtered.is_empty() {
Some(instance_id)
} else {
None
}
})
.collect();
// Compute diff using keys
let current_keys: HashSet<&DiscoveryInstanceId> = current.keys().collect();
let known_keys: HashSet<&DiscoveryInstanceId> = known.iter().collect();
tracing::trace!(
stream_id = %stream_id,
current_ids = ?current_instances.iter().map(|id| format!("{:x}", id)).collect::<Vec<_>>(),
known_ids = ?known_instances.iter().map(|id| format!("{:x}", id)).collect::<Vec<_>>(),
"Comparing instance sets"
);
// Compute diff
let added: Vec<u64> = current_instances
.difference(&known_instances)
.copied()
.collect();
let added: Vec<&DiscoveryInstanceId> =
current_keys.difference(&known_keys).copied().collect();
let removed: Vec<u64> = known_instances
.difference(&current_instances)
.copied()
let removed: Vec<DiscoveryInstanceId> = known_keys
.difference(&current_keys)
.map(|&id| id.clone())
.collect();
// Log diff results (even if empty, for debugging)
......@@ -376,8 +364,6 @@ impl Discovery for KubeDiscoveryClient {
tracing::debug!(
stream_id = %stream_id,
seq = snapshot.sequence,
current_count = current_instances.len(),
known_count = known_instances.len(),
"Watch snapshot received but no diff detected"
);
} else {
......@@ -386,50 +372,47 @@ impl Discovery for KubeDiscoveryClient {
seq = snapshot.sequence,
added = added.len(),
removed = removed.len(),
total = current_instances.len(),
total = current.len(),
"Watch detected changes"
);
}
// Emit Added events
for instance_id in added {
if let Some(metadata) = snapshot.instances.get(&instance_id) {
let instances = metadata.filter(&query);
for instance in instances {
tracing::info!(
for id in added {
if let Some(instance) = current.get(id) {
tracing::info!(
stream_id = %stream_id,
instance_id = format!("{:x}", instance.instance_id()),
"Emitting Added event"
);
if event_tx
.send(Ok(DiscoveryEvent::Added(instance.clone())))
.is_err()
{
tracing::debug!(
stream_id = %stream_id,
instance_id = format!("{:x}", instance.instance_id()),
"Emitting Added event"
"Watch receiver dropped"
);
if event_tx.send(Ok(DiscoveryEvent::Added(instance))).is_err() {
tracing::debug!(
stream_id = %stream_id,
"Watch receiver dropped"
);
return;
}
return;
}
}
}
// Emit Removed events
for instance_id in removed {
for id in removed {
tracing::info!(
stream_id = %stream_id,
instance_id = format!("{:x}", instance_id),
id = ?id,
"Emitting Removed event"
);
if event_tx
.send(Ok(DiscoveryEvent::Removed(instance_id)))
.is_err()
{
if event_tx.send(Ok(DiscoveryEvent::Removed(id))).is_err() {
tracing::debug!(stream_id = %stream_id, "Watch receiver dropped");
return;
}
}
// Update known set
known_instances = current_instances;
known = current.into_keys().collect();
}
Err(_) => {
tracing::info!(
......
......@@ -10,7 +10,8 @@ use futures::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
DiscoverySpec, DiscoveryStream, EndpointInstanceId, ModelCardInstanceId,
};
use crate::storage::kv;
......@@ -394,53 +395,71 @@ impl Discovery for KVStoreDiscovery {
continue;
}
// Extract instance_id from the key path, not the value
// Delete events have empty values in etcd, so we parse the instance_id from the key
// Extract DiscoveryInstanceId from the key path
// Delete events have empty values in etcd, so we reconstruct the ID from the key
//
// Key format (relative to bucket, after stripping bucket prefix):
// - Instances: "namespace/component/endpoint/{instance_id:x}"
// - Endpoints: "namespace/component/endpoint/{instance_id:x}"
// - Models: "namespace/component/endpoint/{instance_id:x}"
// - LoRA models: "namespace/component/endpoint/{instance_id:x}/{lora_slug}"
//
// The instance_id is always at index 3 in the RELATIVE key (after bucket prefix).
// Use strip_bucket_prefix for consistency with matches_prefix().
let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
let key_parts: Vec<&str> = relative_key.split('/').collect();
// In relative key: namespace/component/endpoint/{instance_id}[/{lora_slug}]
// instance_id is at index 3
let instance_id_index = 3;
match key_parts.get(instance_id_index) {
Some(instance_id_hex) => {
match u64::from_str_radix(instance_id_hex, 16) {
Ok(instance_id) => {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={:x}, key={}",
instance_id,
key_str
);
Some(DiscoveryEvent::Removed(instance_id))
}
Err(e) => {
tracing::warn!(
key = %key_str,
relative_key = %relative_key,
error = %e,
instance_id_hex = %instance_id_hex,
"Failed to parse instance_id hex from deleted key"
);
None
}
}
// We need at least 4 parts: namespace, component, endpoint, instance_id
if key_parts.len() < 4 {
tracing::warn!(
key = %key_str,
relative_key = %relative_key,
actual_parts = key_parts.len(),
"Delete event key doesn't have enough parts"
);
continue;
}
let namespace = key_parts[0].to_string();
let component = key_parts[1].to_string();
let endpoint = key_parts[2].to_string();
let instance_id_hex = key_parts[3];
match u64::from_str_radix(instance_id_hex, 16) {
Ok(instance_id) => {
// Construct the appropriate DiscoveryInstanceId based on bucket type
let id = if bucket_name == INSTANCES_BUCKET {
DiscoveryInstanceId::Endpoint(EndpointInstanceId {
namespace,
component,
endpoint,
instance_id,
})
} else {
// Model - check for LoRA suffix (5th part if present)
let model_suffix = key_parts.get(4).map(|s| s.to_string());
DiscoveryInstanceId::Model(ModelCardInstanceId {
namespace,
component,
endpoint,
instance_id,
model_suffix,
})
};
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Removed event for {:?}, key={}",
id,
key_str
);
Some(DiscoveryEvent::Removed(id))
}
None => {
Err(e) => {
tracing::warn!(
key = %key_str,
relative_key = %relative_key,
expected_index = instance_id_index,
actual_parts = key_parts.len(),
"Delete event key doesn't have instance_id at expected position"
error = %e,
instance_id_hex = %instance_id_hex,
"Failed to parse instance_id hex from deleted key"
);
None
}
......
......@@ -5,21 +5,15 @@ use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::{DiscoveryInstance, DiscoveryQuery};
/// Key for organizing metadata internally
/// Format: "namespace/component/endpoint"
fn make_endpoint_key(namespace: &str, component: &str, endpoint: &str) -> String {
format!("{namespace}/{component}/{endpoint}")
}
use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
/// Metadata stored on each pod and exposed via HTTP endpoint
/// This struct holds all discovery registrations for this pod instance
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryMetadata {
/// Registered endpoint instances (key: "namespace/component/endpoint")
/// Registered endpoint instances (key: path string from EndpointInstanceId::to_path())
endpoints: HashMap<String, DiscoveryInstance>,
/// Registered model card instances (key: "namespace/component/endpoint")
/// Registered model card instances (key: path string from ModelCardInstanceId::to_path())
model_cards: HashMap<String, DiscoveryInstance>,
}
......@@ -34,57 +28,53 @@ impl DiscoveryMetadata {
/// Register an endpoint instance
pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Endpoint(ref inst) = instance {
let key = make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint);
self.endpoints.insert(key, instance);
Ok(())
} else {
anyhow::bail!("Cannot register non-endpoint instance as endpoint")
match instance.id() {
DiscoveryInstanceId::Endpoint(key) => {
self.endpoints.insert(key.to_path(), instance);
Ok(())
}
DiscoveryInstanceId::Model(_) => {
anyhow::bail!("Cannot register non-endpoint instance as endpoint")
}
}
}
/// Register a model card instance
pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Model {
ref namespace,
ref component,
ref endpoint,
..
} = instance
{
let key = make_endpoint_key(namespace, component, endpoint);
self.model_cards.insert(key, instance);
Ok(())
} else {
anyhow::bail!("Cannot register non-model-card instance as model card")
match instance.id() {
DiscoveryInstanceId::Model(key) => {
self.model_cards.insert(key.to_path(), instance);
Ok(())
}
DiscoveryInstanceId::Endpoint(_) => {
anyhow::bail!("Cannot register non-model-card instance as model card")
}
}
}
/// Unregister an endpoint instance
pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Endpoint(inst) = instance {
let key = make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint);
self.endpoints.remove(&key);
Ok(())
} else {
anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
match instance.id() {
DiscoveryInstanceId::Endpoint(key) => {
self.endpoints.remove(&key.to_path());
Ok(())
}
DiscoveryInstanceId::Model(_) => {
anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
}
}
}
/// Unregister a model card instance
pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Model {
namespace,
component,
endpoint,
..
} = instance
{
let key = make_endpoint_key(namespace, component, endpoint);
self.model_cards.remove(&key);
Ok(())
} else {
anyhow::bail!("Cannot unregister non-model-card instance as model card")
match instance.id() {
DiscoveryInstanceId::Model(key) => {
self.model_cards.remove(&key.to_path());
Ok(())
}
DiscoveryInstanceId::Endpoint(_) => {
anyhow::bail!("Cannot unregister non-model-card instance as model card")
}
}
}
......
......@@ -2,7 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
DiscoverySpec, DiscoveryStream,
};
use anyhow::Result;
use async_trait::async_trait;
......@@ -171,7 +172,7 @@ impl Discovery for MockDiscovery {
let registry = self.registry.clone();
let stream = async_stream::stream! {
let mut known_instances = HashSet::new();
let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
loop {
let current: Vec<_> = {
......@@ -183,19 +184,11 @@ impl Discovery for MockDiscovery {
.collect()
};
let current_ids: HashSet<_> = current.iter().map(|i| {
match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
}
}).collect();
let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
// Emit Added events for new instances
for instance in current {
let id = match &instance {
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
};
let id = instance.id();
if known_instances.insert(id) {
yield Ok(DiscoveryEvent::Added(instance));
}
......@@ -203,8 +196,8 @@ impl Discovery for MockDiscovery {
// Emit Removed events for instances that are gone
for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
yield Ok(DiscoveryEvent::Removed(id));
known_instances.remove(&id);
yield Ok(DiscoveryEvent::Removed(id));
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
......@@ -272,8 +265,9 @@ mod tests {
let event = stream.next().await.unwrap().unwrap();
match event {
DiscoveryEvent::Removed(instance_id) => {
assert_eq!(instance_id, 1);
DiscoveryEvent::Removed(id) => {
let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
assert_eq!(endpoint_id.instance_id, 1);
}
_ => panic!("Expected Removed event for instance-1"),
}
......
......@@ -198,6 +198,150 @@ impl DiscoveryInstance {
}
}
}
/// Extracts the unique identifier for this discovery instance
/// Used for tracking, diffing, and removal events
pub fn id(&self) -> DiscoveryInstanceId {
match self {
Self::Endpoint(inst) => DiscoveryInstanceId::Endpoint(EndpointInstanceId {
namespace: inst.namespace.clone(),
component: inst.component.clone(),
endpoint: inst.endpoint.clone(),
instance_id: inst.instance_id,
}),
Self::Model {
namespace,
component,
endpoint,
instance_id,
model_suffix,
..
} => DiscoveryInstanceId::Model(ModelCardInstanceId {
namespace: namespace.clone(),
component: component.clone(),
endpoint: endpoint.clone(),
instance_id: *instance_id,
model_suffix: model_suffix.clone(),
}),
}
}
}
/// Unique identifier for an endpoint instance
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EndpointInstanceId {
pub namespace: String,
pub component: String,
pub endpoint: String,
pub instance_id: u64,
}
impl EndpointInstanceId {
/// Converts to a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}`
pub fn to_path(&self) -> String {
format!(
"{}/{}/{}/{:x}",
self.namespace, self.component, self.endpoint, self.instance_id
)
}
/// Parses from a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}`
pub fn from_path(path: &str) -> Result<Self> {
let parts: Vec<&str> = path.split('/').collect();
if parts.len() != 4 {
anyhow::bail!(
"Invalid EndpointInstanceId path: expected 4 parts, got {}",
parts.len()
);
}
Ok(Self {
namespace: parts[0].to_string(),
component: parts[1].to_string(),
endpoint: parts[2].to_string(),
instance_id: u64::from_str_radix(parts[3], 16)
.map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?,
})
}
}
/// Unique identifier for a model card instance
/// The combination of (namespace, component, endpoint, instance_id, model_suffix) uniquely identifies a model card
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelCardInstanceId {
pub namespace: String,
pub component: String,
pub endpoint: String,
pub instance_id: u64,
/// None for base models, Some(slug) for LoRA adapters
pub model_suffix: Option<String>,
}
impl ModelCardInstanceId {
/// Converts to a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`
pub fn to_path(&self) -> String {
match &self.model_suffix {
Some(suffix) => format!(
"{}/{}/{}/{:x}/{}",
self.namespace, self.component, self.endpoint, self.instance_id, suffix
),
None => format!(
"{}/{}/{}/{:x}",
self.namespace, self.component, self.endpoint, self.instance_id
),
}
}
/// Parses from a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`
pub fn from_path(path: &str) -> Result<Self> {
let parts: Vec<&str> = path.split('/').collect();
if parts.len() < 4 || parts.len() > 5 {
anyhow::bail!(
"Invalid ModelCardInstanceId path: expected 4 or 5 parts, got {}",
parts.len()
);
}
Ok(Self {
namespace: parts[0].to_string(),
component: parts[1].to_string(),
endpoint: parts[2].to_string(),
instance_id: u64::from_str_radix(parts[3], 16)
.map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?,
model_suffix: parts.get(4).map(|s| s.to_string()),
})
}
}
/// Union of instance identifiers for different discovery object types
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DiscoveryInstanceId {
Endpoint(EndpointInstanceId),
Model(ModelCardInstanceId),
}
impl DiscoveryInstanceId {
/// Returns the raw instance_id regardless of variant type
pub fn instance_id(&self) -> u64 {
match self {
Self::Endpoint(eid) => eid.instance_id,
Self::Model(mid) => mid.instance_id,
}
}
/// Extracts the EndpointInstanceId, returning an error if this is a Model variant
pub fn extract_endpoint_id(&self) -> Result<&EndpointInstanceId> {
match self {
Self::Endpoint(eid) => Ok(eid),
Self::Model(_) => anyhow::bail!("Expected Endpoint variant, got Model"),
}
}
/// Extracts the ModelCardInstanceId, returning an error if this is an Endpoint variant
pub fn extract_model_id(&self) -> Result<&ModelCardInstanceId> {
match self {
Self::Model(mid) => Ok(mid),
Self::Endpoint(_) => anyhow::bail!("Expected Model variant, got Endpoint"),
}
}
}
/// Events emitted by the discovery watch stream
......@@ -205,8 +349,8 @@ impl DiscoveryInstance {
pub enum DiscoveryEvent {
/// A new instance was added
Added(DiscoveryInstance),
/// An instance was removed (identified by instance_id)
Removed(u64),
/// An instance was removed (identified by its unique ID)
Removed(DiscoveryInstanceId),
}
/// Stream type for discovery events
......
......@@ -84,9 +84,9 @@ where
break;
}
}
Ok(DiscoveryEvent::Removed(instance_id)) => {
Ok(DiscoveryEvent::Removed(id)) => {
// Remove from state and send update
state.remove(&instance_id);
state.remove(&id.instance_id());
if tx.send(state.clone()).is_err() {
tracing::debug!("watch_and_extract_field receiver dropped, stopping");
break;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment