Unverified commit 2f9812aa, authored by mohammedabdulwahhab, committed by GitHub
Browse files

fix: fix bug in diffing logic in list_and_watch (#5318)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
parent 91eb0ed8
...@@ -1306,12 +1306,12 @@ fn spawn_prefill_watcher( ...@@ -1306,12 +1306,12 @@ fn spawn_prefill_watcher(
} }
} }
} }
DiscoveryEvent::Removed(instance_id) => { DiscoveryEvent::Removed(id) => {
// Log removal for observability // Log removal for observability
// Note: The PrefillRouter remains active - worker availability // Note: The PrefillRouter remains active - worker availability
// is handled dynamically by the underlying Client's instance tracking // is handled dynamically by the underlying Client's instance tracking
tracing::debug!( tracing::debug!(
instance_id = instance_id, instance_id = id.instance_id(),
"Prefill worker instance removed from discovery" "Prefill worker instance removed from discovery"
); );
} }
......
...@@ -320,14 +320,14 @@ impl ModelManager { ...@@ -320,14 +320,14 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card); self.cards.lock().insert(key.to_string(), card);
Ok(()) Ok(())
} }
/// Remove and return model card for this instance's etcd key. We do this when the instance stops. /// Remove and return model card for this instance's key. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> { pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.lock().remove(key) self.cards.lock().remove(key)
} }
......
...@@ -10,7 +10,10 @@ use futures::StreamExt; ...@@ -10,7 +10,10 @@ use futures::StreamExt;
use dynamo_runtime::{ use dynamo_runtime::{
DistributedRuntime, DistributedRuntime,
discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoveryStream}, discovery::{
DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery, DiscoveryStream,
ModelCardInstanceId,
},
pipeline::{ pipeline::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source, ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter, network::egress::push_router::PushRouter,
...@@ -119,23 +122,26 @@ impl ModelWatcher { ...@@ -119,23 +122,26 @@ impl ModelWatcher {
match event { match event {
DiscoveryEvent::Added(instance) => { DiscoveryEvent::Added(instance) => {
// Extract EndpointId, instance_id, and card from the discovery instance // Extract ModelCardInstanceId and card from the discovery instance
let (endpoint_id, instance_id, mut card) = match &instance { let (mcid, mut card) = match &instance {
DiscoveryInstance::Model { DiscoveryInstance::Model {
namespace, namespace,
component, component,
endpoint, endpoint,
instance_id, instance_id,
model_suffix,
.. ..
} => { } => {
let eid = EndpointId { let mcid = ModelCardInstanceId {
namespace: namespace.clone(), namespace: namespace.clone(),
component: component.clone(), component: component.clone(),
name: endpoint.clone(), endpoint: endpoint.clone(),
instance_id: *instance_id,
model_suffix: model_suffix.clone(),
}; };
match instance.deserialize_model::<ModelDeploymentCard>() { match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => (eid, *instance_id, card), Ok(card) => (mcid, card),
Err(err) => { Err(err) => {
tracing::error!(%err, instance_id, "Failed to deserialize model card"); tracing::error!(%err, instance_id, "Failed to deserialize model card");
continue; continue;
...@@ -153,10 +159,10 @@ impl ModelWatcher { ...@@ -153,10 +159,10 @@ impl ModelWatcher {
// Filter by namespace if target_namespace is specified // Filter by namespace if target_namespace is specified
if !global_namespace if !global_namespace
&& let Some(target_ns) = target_namespace && let Some(target_ns) = target_namespace
&& endpoint_id.namespace != target_ns && mcid.namespace != target_ns
{ {
tracing::debug!( tracing::debug!(
model_namespace = endpoint_id.namespace, model_namespace = mcid.namespace,
target_namespace = target_ns, target_namespace = target_ns,
"Skipping model from different namespace" "Skipping model from different namespace"
); );
...@@ -185,14 +191,11 @@ impl ModelWatcher { ...@@ -185,14 +191,11 @@ impl ModelWatcher {
continue; continue;
} }
// Use instance_id as the HashMap key (simpler and sufficient since keys are opaque) match self.handle_put(&mcid, &mut card).await {
let key = format!("{:x}", instance_id);
match self.handle_put(&key, &endpoint_id, &mut card).await {
Ok(()) => { Ok(()) => {
tracing::info!( tracing::info!(
model_name = card.name(), model_name = card.name(),
namespace = endpoint_id.namespace, namespace = mcid.namespace,
"added model" "added model"
); );
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
...@@ -200,19 +203,27 @@ impl ModelWatcher { ...@@ -200,19 +203,27 @@ impl ModelWatcher {
Err(err) => { Err(err) => {
tracing::error!( tracing::error!(
model_name = card.name(), model_name = card.name(),
namespace = endpoint_id.namespace, namespace = mcid.namespace,
error = format!("{err:#}"), error = format!("{err:#}"),
"Error adding model from discovery", "Error adding model from discovery",
); );
} }
} }
} }
DiscoveryEvent::Removed(instance_id) => { DiscoveryEvent::Removed(id) => {
// Use instance_id hex as the HashMap key (matches what we saved with) // Extract ModelCardInstanceId from the removal event
let key = format!("{:x}", instance_id); let model_card_instance_id = match &id {
DiscoveryInstanceId::Model(mcid) => mcid,
DiscoveryInstanceId::Endpoint(_) => {
tracing::error!(
"Unexpected discovery instance type in removal (expected Model)"
);
continue;
}
};
match self match self
.handle_delete(&key, target_namespace, global_namespace) .handle_delete(model_card_instance_id, target_namespace, global_namespace)
.await .await
{ {
Ok(Some(model_name)) => { Ok(Some(model_name)) => {
...@@ -234,14 +245,15 @@ impl ModelWatcher { ...@@ -234,14 +245,15 @@ impl ModelWatcher {
/// Returns the name of the model we just deleted, if any. /// Returns the name of the model we just deleted, if any.
async fn handle_delete( async fn handle_delete(
&self, &self,
key: &str, mcid: &ModelCardInstanceId,
target_namespace: Option<&str>, target_namespace: Option<&str>,
is_global_namespace: bool, is_global_namespace: bool,
) -> anyhow::Result<Option<String>> { ) -> anyhow::Result<Option<String>> {
let card = match self.manager.remove_model_card(key) { let key = mcid.to_path();
let card = match self.manager.remove_model_card(&key) {
Some(card) => card, Some(card) => card,
None => { None => {
anyhow::bail!("Missing ModelDeploymentCard for {key}"); anyhow::bail!("Missing ModelDeploymentCard for {}", key);
} }
}; };
let model_name = card.name().to_string(); let model_name = card.name().to_string();
...@@ -325,20 +337,20 @@ impl ModelWatcher { ...@@ -325,20 +337,20 @@ impl ModelWatcher {
// models. // models.
async fn handle_put( async fn handle_put(
&self, &self,
key: &str, mcid: &ModelCardInstanceId,
endpoint_id: &EndpointId,
card: &mut ModelDeploymentCard, card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
card.download_config().await?; card.download_config().await?;
let component = self let component = self
.drt .drt
.namespace(&endpoint_id.namespace)? .namespace(&mcid.namespace)?
.component(&endpoint_id.component)?; .component(&mcid.component)?;
let endpoint = component.endpoint(&endpoint_id.name); let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?; let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model"); tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?; self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
// Skip duplicate registrations based on model type. // Skip duplicate registrations based on model type.
// Prefill and decode models are tracked separately, so registering one // Prefill and decode models are tracked separately, so registering one
...@@ -352,7 +364,7 @@ impl ModelWatcher { ...@@ -352,7 +364,7 @@ impl ModelWatcher {
if already_registered { if already_registered {
tracing::debug!( tracing::debug!(
model_name = card.name(), model_name = card.name(),
namespace = endpoint_id.namespace, namespace = mcid.namespace,
model_type = %card.model_type, model_type = %card.model_type,
"Model already registered, skipping" "Model already registered, skipping"
); );
...@@ -372,7 +384,7 @@ impl ModelWatcher { ...@@ -372,7 +384,7 @@ impl ModelWatcher {
// A model that expects pre-processed requests meaning it's up to us whether we // A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports. // handle Chat or Completions requests, so handle whatever the model supports.
let endpoint = component.endpoint(&endpoint_id.name); let endpoint = component.endpoint(&mcid.endpoint);
let kv_chooser = if self.router_config.router_mode == RouterMode::KV { let kv_chooser = if self.router_config.router_mode == RouterMode::KV {
Some( Some(
self.manager self.manager
......
...@@ -564,10 +564,12 @@ pub async fn start_kv_router_background( ...@@ -564,10 +564,12 @@ pub async fn start_kv_router_background(
continue; continue;
}; };
let DiscoveryEvent::Removed(worker_id) = discovery_event else { let DiscoveryEvent::Removed(id) = discovery_event else {
continue; continue;
}; };
let worker_id = id.instance_id();
tracing::warn!( tracing::warn!(
"DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}" "DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
); );
...@@ -642,11 +644,13 @@ pub async fn start_kv_router_background( ...@@ -642,11 +644,13 @@ pub async fn start_kv_router_background(
continue; continue;
}; };
let DiscoveryEvent::Removed(router_instance_id) = router_event else { let DiscoveryEvent::Removed(id) = router_event else {
// We only care about removals for cleaning up consumers // We only care about removals for cleaning up consumers
continue; continue;
}; };
let router_instance_id = id.instance_id();
// The consumer UUID is the instance_id in hex format // The consumer UUID is the instance_id in hex format
let consumer_to_delete = router_instance_id.to_string(); let consumer_to_delete = router_instance_id.to_string();
...@@ -708,7 +712,8 @@ async fn handle_worker_discovery( ...@@ -708,7 +712,8 @@ async fn handle_worker_discovery(
} }
} }
} }
DiscoveryEvent::Removed(worker_id) => { DiscoveryEvent::Removed(id) => {
let worker_id = id.instance_id();
tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer"); tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
if let Err(e) = remove_worker_tx.send(worker_id).await { if let Err(e) = remove_worker_tx.send(worker_id).await {
......
...@@ -9,7 +9,7 @@ use arc_swap::ArcSwap; ...@@ -9,7 +9,7 @@ use arc_swap::ArcSwap;
use futures::StreamExt; use futures::StreamExt;
use tokio::net::unix::pipe::Receiver; use tokio::net::unix::pipe::Receiver;
use crate::discovery::{DiscoveryEvent, DiscoveryInstance}; use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
use crate::{ use crate::{
component::{Endpoint, Instance}, component::{Endpoint, Instance},
pipeline::async_trait, pipeline::async_trait,
...@@ -255,8 +255,8 @@ impl Client { ...@@ -255,8 +255,8 @@ impl Client {
map.insert(instance.instance_id, instance); map.insert(instance.instance_id, instance);
} }
} }
DiscoveryEvent::Removed(instance_id) => { DiscoveryEvent::Removed(id) => {
map.remove(&instance_id); map.remove(&id.instance_id());
} }
} }
......
...@@ -14,8 +14,8 @@ use utils::PodInfo; ...@@ -14,8 +14,8 @@ use utils::PodInfo;
use crate::CancellationToken; use crate::CancellationToken;
use crate::discovery::{ use crate::discovery::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryMetadata, DiscoveryQuery, DiscoverySpec, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryMetadata,
DiscoveryStream, MetadataSnapshot, DiscoveryQuery, DiscoverySpec, DiscoveryStream, MetadataSnapshot,
}; };
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -257,56 +257,53 @@ impl Discovery for KubeDiscoveryClient { ...@@ -257,56 +257,53 @@ impl Discovery for KubeDiscoveryClient {
// Spawn task to process snapshots // Spawn task to process snapshots
tokio::spawn(async move { tokio::spawn(async move {
// Initialize known_instances from current snapshot state // Initialize from current snapshot state
// This is critical: watch_rx.changed() only fires on FUTURE changes, // This is critical: watch_rx.changed() only fires on FUTURE changes,
// so we must capture the current state first to detect removals correctly // so we must capture the current state first to detect removals correctly
let initial_snapshot = watch_rx.borrow_and_update().clone(); let initial_snapshot = watch_rx.borrow_and_update().clone();
let mut known_instances: HashSet<u64> = initial_snapshot // Build initial map: DiscoveryInstanceId -> DiscoveryInstance
.instances let initial: std::collections::HashMap<DiscoveryInstanceId, DiscoveryInstance> =
.iter() initial_snapshot
.filter_map(|(&instance_id, metadata)| { .instances
let filtered = metadata.filter(&query); .values()
if !filtered.is_empty() { .flat_map(|metadata| metadata.filter(&query))
Some(instance_id) .map(|instance| (instance.id(), instance))
} else { .collect();
None
}
})
.collect();
tracing::debug!( tracing::debug!(
stream_id = %stream_id, stream_id = %stream_id,
initial_instances = known_instances.len(), initial_count = initial.len(),
"Watch started for query={:?}", "Watch started for query={:?}",
query query
); );
// Emit initial Added events for all existing instances (the "list" part of list_and_watch) // Emit initial Added events (the "list" part of list_and_watch)
for &instance_id in &known_instances { for instance in initial.values() {
if let Some(metadata) = initial_snapshot.instances.get(&instance_id) { tracing::info!(
let instances = metadata.filter(&query); stream_id = %stream_id,
for instance in instances { instance_id = format!("{:x}", instance.instance_id()),
tracing::info!( "Emitting initial Added event"
stream_id = %stream_id, );
instance_id = format!("{:x}", instance.instance_id()), if event_tx
"Emitting initial Added event" .send(Ok(DiscoveryEvent::Added(instance.clone())))
); .is_err()
if event_tx.send(Ok(DiscoveryEvent::Added(instance))).is_err() { {
tracing::debug!( tracing::debug!(
stream_id = %stream_id, stream_id = %stream_id,
"Watch receiver dropped during initial sync" "Watch receiver dropped during initial sync"
); );
return; return;
}
}
} }
} }
// Track known instances by their unique ID
let mut known: HashSet<DiscoveryInstanceId> = initial.into_keys().collect();
loop { loop {
tracing::trace!( tracing::trace!(
stream_id = %stream_id, stream_id = %stream_id,
known_count = known_instances.len(), known_count = known.len(),
"Watch loop waiting for changes" "Watch loop waiting for changes"
); );
...@@ -331,44 +328,35 @@ impl Discovery for KubeDiscoveryClient { ...@@ -331,44 +328,35 @@ impl Discovery for KubeDiscoveryClient {
// Get latest snapshot // Get latest snapshot
let snapshot = watch_rx.borrow_and_update().clone(); let snapshot = watch_rx.borrow_and_update().clone();
// Build current map: DiscoveryInstanceId -> DiscoveryInstance
let current: std::collections::HashMap<
DiscoveryInstanceId,
DiscoveryInstance,
> = snapshot
.instances
.values()
.flat_map(|metadata| metadata.filter(&query))
.map(|instance| (instance.id(), instance))
.collect();
tracing::debug!( tracing::debug!(
stream_id = %stream_id, stream_id = %stream_id,
seq = snapshot.sequence, seq = snapshot.sequence,
snapshot_instances = snapshot.instances.len(), current_count = current.len(),
known_instances = known_instances.len(), known_count = known.len(),
"Watch received snapshot update" "Watch received snapshot update"
); );
// Filter snapshot by query // Compute diff using keys
let current_instances: HashSet<u64> = snapshot let current_keys: HashSet<&DiscoveryInstanceId> = current.keys().collect();
.instances let known_keys: HashSet<&DiscoveryInstanceId> = known.iter().collect();
.iter()
.filter_map(|(&instance_id, metadata)| {
let filtered = metadata.filter(&query);
if !filtered.is_empty() {
Some(instance_id)
} else {
None
}
})
.collect();
tracing::trace!( let added: Vec<&DiscoveryInstanceId> =
stream_id = %stream_id, current_keys.difference(&known_keys).copied().collect();
current_ids = ?current_instances.iter().map(|id| format!("{:x}", id)).collect::<Vec<_>>(),
known_ids = ?known_instances.iter().map(|id| format!("{:x}", id)).collect::<Vec<_>>(),
"Comparing instance sets"
);
// Compute diff
let added: Vec<u64> = current_instances
.difference(&known_instances)
.copied()
.collect();
let removed: Vec<u64> = known_instances let removed: Vec<DiscoveryInstanceId> = known_keys
.difference(&current_instances) .difference(&current_keys)
.copied() .map(|&id| id.clone())
.collect(); .collect();
// Log diff results (even if empty, for debugging) // Log diff results (even if empty, for debugging)
...@@ -376,8 +364,6 @@ impl Discovery for KubeDiscoveryClient { ...@@ -376,8 +364,6 @@ impl Discovery for KubeDiscoveryClient {
tracing::debug!( tracing::debug!(
stream_id = %stream_id, stream_id = %stream_id,
seq = snapshot.sequence, seq = snapshot.sequence,
current_count = current_instances.len(),
known_count = known_instances.len(),
"Watch snapshot received but no diff detected" "Watch snapshot received but no diff detected"
); );
} else { } else {
...@@ -386,50 +372,47 @@ impl Discovery for KubeDiscoveryClient { ...@@ -386,50 +372,47 @@ impl Discovery for KubeDiscoveryClient {
seq = snapshot.sequence, seq = snapshot.sequence,
added = added.len(), added = added.len(),
removed = removed.len(), removed = removed.len(),
total = current_instances.len(), total = current.len(),
"Watch detected changes" "Watch detected changes"
); );
} }
// Emit Added events // Emit Added events
for instance_id in added { for id in added {
if let Some(metadata) = snapshot.instances.get(&instance_id) { if let Some(instance) = current.get(id) {
let instances = metadata.filter(&query); tracing::info!(
for instance in instances { stream_id = %stream_id,
tracing::info!( instance_id = format!("{:x}", instance.instance_id()),
"Emitting Added event"
);
if event_tx
.send(Ok(DiscoveryEvent::Added(instance.clone())))
.is_err()
{
tracing::debug!(
stream_id = %stream_id, stream_id = %stream_id,
instance_id = format!("{:x}", instance.instance_id()), "Watch receiver dropped"
"Emitting Added event"
); );
if event_tx.send(Ok(DiscoveryEvent::Added(instance))).is_err() { return;
tracing::debug!(
stream_id = %stream_id,
"Watch receiver dropped"
);
return;
}
} }
} }
} }
// Emit Removed events // Emit Removed events
for instance_id in removed { for id in removed {
tracing::info!( tracing::info!(
stream_id = %stream_id, stream_id = %stream_id,
instance_id = format!("{:x}", instance_id), id = ?id,
"Emitting Removed event" "Emitting Removed event"
); );
if event_tx if event_tx.send(Ok(DiscoveryEvent::Removed(id))).is_err() {
.send(Ok(DiscoveryEvent::Removed(instance_id)))
.is_err()
{
tracing::debug!(stream_id = %stream_id, "Watch receiver dropped"); tracing::debug!(stream_id = %stream_id, "Watch receiver dropped");
return; return;
} }
} }
// Update known set // Update known set
known_instances = current_instances; known = current.into_keys().collect();
} }
Err(_) => { Err(_) => {
tracing::info!( tracing::info!(
......
...@@ -10,7 +10,8 @@ use futures::{Stream, StreamExt}; ...@@ -10,7 +10,8 @@ use futures::{Stream, StreamExt};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::{ use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
DiscoverySpec, DiscoveryStream, EndpointInstanceId, ModelCardInstanceId,
}; };
use crate::storage::kv; use crate::storage::kv;
...@@ -394,53 +395,71 @@ impl Discovery for KVStoreDiscovery { ...@@ -394,53 +395,71 @@ impl Discovery for KVStoreDiscovery {
continue; continue;
} }
// Extract instance_id from the key path, not the value // Extract DiscoveryInstanceId from the key path
// Delete events have empty values in etcd, so we parse the instance_id from the key // Delete events have empty values in etcd, so we reconstruct the ID from the key
// //
// Key format (relative to bucket, after stripping bucket prefix): // Key format (relative to bucket, after stripping bucket prefix):
// - Instances: "namespace/component/endpoint/{instance_id:x}" // - Endpoints: "namespace/component/endpoint/{instance_id:x}"
// - Models: "namespace/component/endpoint/{instance_id:x}" // - Models: "namespace/component/endpoint/{instance_id:x}"
// - LoRA models: "namespace/component/endpoint/{instance_id:x}/{lora_slug}" // - LoRA models: "namespace/component/endpoint/{instance_id:x}/{lora_slug}"
// //
// The instance_id is always at index 3 in the RELATIVE key (after bucket prefix).
// Use strip_bucket_prefix for consistency with matches_prefix(). // Use strip_bucket_prefix for consistency with matches_prefix().
let relative_key = Self::strip_bucket_prefix(key_str, bucket_name); let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
let key_parts: Vec<&str> = relative_key.split('/').collect(); let key_parts: Vec<&str> = relative_key.split('/').collect();
// In relative key: namespace/component/endpoint/{instance_id}[/{lora_slug}] // In relative key: namespace/component/endpoint/{instance_id}[/{lora_slug}]
// instance_id is at index 3 // We need at least 4 parts: namespace, component, endpoint, instance_id
let instance_id_index = 3; if key_parts.len() < 4 {
tracing::warn!(
match key_parts.get(instance_id_index) { key = %key_str,
Some(instance_id_hex) => { relative_key = %relative_key,
match u64::from_str_radix(instance_id_hex, 16) { actual_parts = key_parts.len(),
Ok(instance_id) => { "Delete event key doesn't have enough parts"
tracing::debug!( );
"KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={:x}, key={}", continue;
instance_id, }
key_str
); let namespace = key_parts[0].to_string();
Some(DiscoveryEvent::Removed(instance_id)) let component = key_parts[1].to_string();
} let endpoint = key_parts[2].to_string();
Err(e) => { let instance_id_hex = key_parts[3];
tracing::warn!(
key = %key_str, match u64::from_str_radix(instance_id_hex, 16) {
relative_key = %relative_key, Ok(instance_id) => {
error = %e, // Construct the appropriate DiscoveryInstanceId based on bucket type
instance_id_hex = %instance_id_hex, let id = if bucket_name == INSTANCES_BUCKET {
"Failed to parse instance_id hex from deleted key" DiscoveryInstanceId::Endpoint(EndpointInstanceId {
); namespace,
None component,
} endpoint,
} instance_id,
})
} else {
// Model - check for LoRA suffix (5th part if present)
let model_suffix = key_parts.get(4).map(|s| s.to_string());
DiscoveryInstanceId::Model(ModelCardInstanceId {
namespace,
component,
endpoint,
instance_id,
model_suffix,
})
};
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Removed event for {:?}, key={}",
id,
key_str
);
Some(DiscoveryEvent::Removed(id))
} }
None => { Err(e) => {
tracing::warn!( tracing::warn!(
key = %key_str, key = %key_str,
relative_key = %relative_key, relative_key = %relative_key,
expected_index = instance_id_index, error = %e,
actual_parts = key_parts.len(), instance_id_hex = %instance_id_hex,
"Delete event key doesn't have instance_id at expected position" "Failed to parse instance_id hex from deleted key"
); );
None None
} }
......
...@@ -5,21 +5,15 @@ use anyhow::Result; ...@@ -5,21 +5,15 @@ use anyhow::Result;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use super::{DiscoveryInstance, DiscoveryQuery}; use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
/// Key for organizing metadata internally
/// Format: "namespace/component/endpoint"
fn make_endpoint_key(namespace: &str, component: &str, endpoint: &str) -> String {
format!("{namespace}/{component}/{endpoint}")
}
/// Metadata stored on each pod and exposed via HTTP endpoint /// Metadata stored on each pod and exposed via HTTP endpoint
/// This struct holds all discovery registrations for this pod instance /// This struct holds all discovery registrations for this pod instance
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryMetadata { pub struct DiscoveryMetadata {
/// Registered endpoint instances (key: "namespace/component/endpoint") /// Registered endpoint instances (key: path string from EndpointInstanceId::to_path())
endpoints: HashMap<String, DiscoveryInstance>, endpoints: HashMap<String, DiscoveryInstance>,
/// Registered model card instances (key: "namespace/component/endpoint") /// Registered model card instances (key: path string from ModelCardInstanceId::to_path())
model_cards: HashMap<String, DiscoveryInstance>, model_cards: HashMap<String, DiscoveryInstance>,
} }
...@@ -34,57 +28,53 @@ impl DiscoveryMetadata { ...@@ -34,57 +28,53 @@ impl DiscoveryMetadata {
/// Register an endpoint instance /// Register an endpoint instance
pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> { pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Endpoint(ref inst) = instance { match instance.id() {
let key = make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint); DiscoveryInstanceId::Endpoint(key) => {
self.endpoints.insert(key, instance); self.endpoints.insert(key.to_path(), instance);
Ok(()) Ok(())
} else { }
anyhow::bail!("Cannot register non-endpoint instance as endpoint") DiscoveryInstanceId::Model(_) => {
anyhow::bail!("Cannot register non-endpoint instance as endpoint")
}
} }
} }
/// Register a model card instance /// Register a model card instance
pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> { pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Model { match instance.id() {
ref namespace, DiscoveryInstanceId::Model(key) => {
ref component, self.model_cards.insert(key.to_path(), instance);
ref endpoint, Ok(())
.. }
} = instance DiscoveryInstanceId::Endpoint(_) => {
{ anyhow::bail!("Cannot register non-model-card instance as model card")
let key = make_endpoint_key(namespace, component, endpoint); }
self.model_cards.insert(key, instance);
Ok(())
} else {
anyhow::bail!("Cannot register non-model-card instance as model card")
} }
} }
/// Unregister an endpoint instance /// Unregister an endpoint instance
pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> { pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Endpoint(inst) = instance { match instance.id() {
let key = make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint); DiscoveryInstanceId::Endpoint(key) => {
self.endpoints.remove(&key); self.endpoints.remove(&key.to_path());
Ok(()) Ok(())
} else { }
anyhow::bail!("Cannot unregister non-endpoint instance as endpoint") DiscoveryInstanceId::Model(_) => {
anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
}
} }
} }
/// Unregister a model card instance /// Unregister a model card instance
pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> { pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
if let DiscoveryInstance::Model { match instance.id() {
namespace, DiscoveryInstanceId::Model(key) => {
component, self.model_cards.remove(&key.to_path());
endpoint, Ok(())
.. }
} = instance DiscoveryInstanceId::Endpoint(_) => {
{ anyhow::bail!("Cannot unregister non-model-card instance as model card")
let key = make_endpoint_key(namespace, component, endpoint); }
self.model_cards.remove(&key);
Ok(())
} else {
anyhow::bail!("Cannot unregister non-model-card instance as model card")
} }
} }
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{ use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
DiscoverySpec, DiscoveryStream,
}; };
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -171,7 +172,7 @@ impl Discovery for MockDiscovery { ...@@ -171,7 +172,7 @@ impl Discovery for MockDiscovery {
let registry = self.registry.clone(); let registry = self.registry.clone();
let stream = async_stream::stream! { let stream = async_stream::stream! {
let mut known_instances = HashSet::new(); let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
loop { loop {
let current: Vec<_> = { let current: Vec<_> = {
...@@ -183,19 +184,11 @@ impl Discovery for MockDiscovery { ...@@ -183,19 +184,11 @@ impl Discovery for MockDiscovery {
.collect() .collect()
}; };
let current_ids: HashSet<_> = current.iter().map(|i| { let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
}
}).collect();
// Emit Added events for new instances // Emit Added events for new instances
for instance in current { for instance in current {
let id = match &instance { let id = instance.id();
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
};
if known_instances.insert(id) { if known_instances.insert(id) {
yield Ok(DiscoveryEvent::Added(instance)); yield Ok(DiscoveryEvent::Added(instance));
} }
...@@ -203,8 +196,8 @@ impl Discovery for MockDiscovery { ...@@ -203,8 +196,8 @@ impl Discovery for MockDiscovery {
// Emit Removed events for instances that are gone // Emit Removed events for instances that are gone
for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() { for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
yield Ok(DiscoveryEvent::Removed(id));
known_instances.remove(&id); known_instances.remove(&id);
yield Ok(DiscoveryEvent::Removed(id));
} }
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
...@@ -272,8 +265,9 @@ mod tests { ...@@ -272,8 +265,9 @@ mod tests {
let event = stream.next().await.unwrap().unwrap(); let event = stream.next().await.unwrap().unwrap();
match event { match event {
DiscoveryEvent::Removed(instance_id) => { DiscoveryEvent::Removed(id) => {
assert_eq!(instance_id, 1); let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
assert_eq!(endpoint_id.instance_id, 1);
} }
_ => panic!("Expected Removed event for instance-1"), _ => panic!("Expected Removed event for instance-1"),
} }
......
...@@ -198,6 +198,150 @@ impl DiscoveryInstance { ...@@ -198,6 +198,150 @@ impl DiscoveryInstance {
} }
} }
} }
/// Builds the [`DiscoveryInstanceId`] that identifies this discovery instance.
///
/// The identifier is the value used for tracking, diffing, and removal events.
/// It is an owned value: every string field is cloned out of `self`, so the
/// result does not borrow from the instance it was derived from.
pub fn id(&self) -> DiscoveryInstanceId {
    match self {
        Self::Endpoint(ep) => {
            let endpoint_id = EndpointInstanceId {
                namespace: ep.namespace.clone(),
                component: ep.component.clone(),
                endpoint: ep.endpoint.clone(),
                instance_id: ep.instance_id,
            };
            DiscoveryInstanceId::Endpoint(endpoint_id)
        }
        Self::Model {
            namespace,
            component,
            endpoint,
            instance_id,
            model_suffix,
            ..
        } => {
            // The model suffix is part of the identity, so it is copied into
            // the identifier alongside the endpoint coordinates.
            let model_id = ModelCardInstanceId {
                namespace: namespace.clone(),
                component: component.clone(),
                endpoint: endpoint.clone(),
                instance_id: *instance_id,
                model_suffix: model_suffix.clone(),
            };
            DiscoveryInstanceId::Model(model_id)
        }
    }
}
}
/// Unique identifier for an endpoint instance.
///
/// Two identifiers are equal only when all four fields are equal. The type
/// derives `Eq` and `Hash` and can be serialized/deserialized via serde.
/// Its textual form is produced by `to_path` and parsed back by `from_path`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EndpointInstanceId {
/// Namespace part; first segment of the path form.
pub namespace: String,
/// Component part; second segment of the path form.
pub component: String,
/// Endpoint part; third segment of the path form.
pub endpoint: String,
/// Numeric instance id; written as lowercase hexadecimal in the path form.
pub instance_id: u64,
}
impl EndpointInstanceId {
    /// Renders this identifier as `{namespace}/{component}/{endpoint}/{instance_id:x}`.
    ///
    /// The instance id is written as lowercase hexadecimal with no prefix.
    pub fn to_path(&self) -> String {
        let Self {
            namespace,
            component,
            endpoint,
            instance_id,
        } = self;
        format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
    }

    /// Parses an identifier from `{namespace}/{component}/{endpoint}/{instance_id:x}`.
    ///
    /// # Errors
    ///
    /// Returns an error when `path` does not split on `/` into exactly four
    /// parts, or when the fourth part cannot be parsed as a base-16 `u64`.
    pub fn from_path(path: &str) -> Result<Self> {
        let segments: Vec<&str> = path.split('/').collect();
        match segments.as_slice() {
            [namespace, component, endpoint, instance_hex] => {
                let instance_id = u64::from_str_radix(instance_hex, 16)
                    .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?;
                Ok(Self {
                    namespace: namespace.to_string(),
                    component: component.to_string(),
                    endpoint: endpoint.to_string(),
                    instance_id,
                })
            }
            // Any other segment count is malformed; report how many were seen.
            _ => anyhow::bail!(
                "Invalid EndpointInstanceId path: expected 4 parts, got {}",
                segments.len()
            ),
        }
    }
}
/// Unique identifier for a model card instance.
///
/// The combination of (namespace, component, endpoint, instance_id, model_suffix)
/// uniquely identifies a model card. The type derives `Eq` and `Hash` and can be
/// serialized/deserialized via serde. Its textual form is produced by `to_path`
/// and parsed back by `from_path`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelCardInstanceId {
/// Namespace part; first segment of the path form.
pub namespace: String,
/// Component part; second segment of the path form.
pub component: String,
/// Endpoint part; third segment of the path form.
pub endpoint: String,
/// Numeric instance id; written as lowercase hexadecimal in the path form.
pub instance_id: u64,
/// None for base models, Some(slug) for LoRA adapters.
/// When present, it is appended as a fifth segment of the path form.
pub model_suffix: Option<String>,
}
impl ModelCardInstanceId {
    /// Renders this identifier as
    /// `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`.
    ///
    /// The trailing `/{model_suffix}` segment is emitted only when
    /// `model_suffix` is `Some`. The instance id is lowercase hexadecimal.
    pub fn to_path(&self) -> String {
        if let Some(suffix) = &self.model_suffix {
            format!(
                "{}/{}/{}/{:x}/{}",
                self.namespace, self.component, self.endpoint, self.instance_id, suffix
            )
        } else {
            format!(
                "{}/{}/{}/{:x}",
                self.namespace, self.component, self.endpoint, self.instance_id
            )
        }
    }

    /// Parses an identifier from
    /// `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`.
    ///
    /// A fifth `/`-separated part, when present, becomes `model_suffix`.
    ///
    /// # Errors
    ///
    /// Returns an error when `path` splits on `/` into fewer than four or more
    /// than five parts, or when the fourth part cannot be parsed as a base-16
    /// `u64`.
    pub fn from_path(path: &str) -> Result<Self> {
        let segments: Vec<&str> = path.split('/').collect();
        let count = segments.len();
        if !(4..=5).contains(&count) {
            anyhow::bail!(
                "Invalid ModelCardInstanceId path: expected 4 or 5 parts, got {}",
                count
            );
        }

        // Indexing below is in bounds: the length was just checked to be 4 or 5.
        let instance_id = u64::from_str_radix(segments[3], 16)
            .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?;
        let model_suffix = segments.get(4).map(|suffix| (*suffix).to_owned());

        Ok(Self {
            namespace: segments[0].to_owned(),
            component: segments[1].to_owned(),
            endpoint: segments[2].to_owned(),
            instance_id,
            model_suffix,
        })
    }
}
/// Union of instance identifiers for different discovery object types.
///
/// Derives `Eq` and `Hash`, so values can be stored in hash-based collections
/// and compared directly; an `Endpoint` id never equals a `Model` id.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DiscoveryInstanceId {
/// Identifier of an endpoint instance.
Endpoint(EndpointInstanceId),
/// Identifier of a model card instance.
Model(ModelCardInstanceId),
}
impl DiscoveryInstanceId {
/// Returns the raw instance_id regardless of variant type
pub fn instance_id(&self) -> u64 {
match self {
Self::Endpoint(eid) => eid.instance_id,
Self::Model(mid) => mid.instance_id,
}
}
/// Borrows the inner [`EndpointInstanceId`].
///
/// # Errors
///
/// Returns an error if this identifier is the `Model` variant.
pub fn extract_endpoint_id(&self) -> Result<&EndpointInstanceId> {
    match self {
        Self::Model(_) => Err(anyhow::anyhow!("Expected Endpoint variant, got Model")),
        Self::Endpoint(endpoint_id) => Ok(endpoint_id),
    }
}
/// Borrows the inner [`ModelCardInstanceId`].
///
/// # Errors
///
/// Returns an error if this identifier is the `Endpoint` variant.
pub fn extract_model_id(&self) -> Result<&ModelCardInstanceId> {
    match self {
        Self::Endpoint(_) => Err(anyhow::anyhow!("Expected Model variant, got Endpoint")),
        Self::Model(model_id) => Ok(model_id),
    }
}
} }
/// Events emitted by the discovery watch stream /// Events emitted by the discovery watch stream
...@@ -205,8 +349,8 @@ impl DiscoveryInstance { ...@@ -205,8 +349,8 @@ impl DiscoveryInstance {
pub enum DiscoveryEvent { pub enum DiscoveryEvent {
/// A new instance was added /// A new instance was added
Added(DiscoveryInstance), Added(DiscoveryInstance),
/// An instance was removed (identified by instance_id) /// An instance was removed (identified by its unique ID)
Removed(u64), Removed(DiscoveryInstanceId),
} }
/// Stream type for discovery events /// Stream type for discovery events
......
...@@ -84,9 +84,9 @@ where ...@@ -84,9 +84,9 @@ where
break; break;
} }
} }
Ok(DiscoveryEvent::Removed(instance_id)) => { Ok(DiscoveryEvent::Removed(id)) => {
// Remove from state and send update // Remove from state and send update
state.remove(&instance_id); state.remove(&id.instance_id());
if tx.send(state.clone()).is_err() { if tx.send(state.clone()).is_err() {
tracing::debug!("watch_and_extract_field receiver dropped, stopping"); tracing::debug!("watch_and_extract_field receiver dropped, stopping");
break; break;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment