"components/vscode:/vscode.git/clone" did not exist on "07a64744b933c38ed43a48af9f864dfc98bfc9f4"
Unverified Commit 09b26bf6 authored by mohammedabdulwahhab's avatar mohammedabdulwahhab Committed by GitHub
Browse files

fix: refactor to use service discovery (#4092)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
parent 04f7579b
...@@ -72,17 +72,14 @@ async def test_hello_world(example_dir, server_process): ...@@ -72,17 +72,14 @@ async def test_hello_world(example_dir, server_process):
# Run the client for 5 seconds # Run the client for 5 seconds
client_output = await run_client(example_dir) client_output = await run_client(example_dir)
# Split output into lines and filter out empty lines # Split output into lines and strip whitespace, filter out empty lines
lines = [line.strip() for line in client_output.split("\n") if line.strip()] lines = [line.strip() for line in client_output.split("\n") if line.strip()]
# Each client iteration produces 4 lines in about 4 seconds # Each client iteration produces 4 lines in about 4 seconds
# The client ran for 5 seconds so the first iteration is expected to be completed # The client ran for 5 seconds so the first iteration is expected to be completed
# Assert the first 4 lines are the expected sequence # Check that all 4 expected lines appear in the output
assert (
len(lines) >= 4
), f"Expected at least 4 lines, got {len(lines)}. Output: {lines}"
expected_lines = ["Hello world!", "Hello sun!", "Hello moon!", "Hello star!"] expected_lines = ["Hello world!", "Hello sun!", "Hello moon!", "Hello star!"]
for i, expected_line in enumerate(expected_lines): for expected_line in expected_lines:
assert ( assert expected_line in lines, (
lines[i] == expected_line f"Expected line '{expected_line}' not found in output.\n" f"Lines: {lines}"
), f"Line {i+1}: expected '{expected_line}', got '{lines[i]}'. Full output: {lines}" )
...@@ -2,26 +2,27 @@ ...@@ -2,26 +2,27 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Notify;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use anyhow::Context as _; use anyhow::Context as _;
use tokio::sync::{Notify, mpsc::Receiver}; use futures::StreamExt;
use dynamo_runtime::{ use dynamo_runtime::{
DistributedRuntime, DistributedRuntime,
discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoveryStream},
pipeline::{ pipeline::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source, ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter, network::egress::push_router::PushRouter,
}, },
protocols::{EndpointId, annotated::Annotated}, protocols::{EndpointId, annotated::Annotated},
storage::key_value_store::WatchEvent,
}; };
use crate::{ use crate::{
backend::Backend, backend::Backend,
entrypoint, entrypoint,
kv_router::{KvRouterConfig, PrefillRouter}, kv_router::{KvRouterConfig, PrefillRouter},
model_card::{self, ModelDeploymentCard}, model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType}, model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{ protocols::{
...@@ -99,17 +100,51 @@ impl ModelWatcher { ...@@ -99,17 +100,51 @@ impl ModelWatcher {
} }
/// Common watch logic with optional namespace filtering /// Common watch logic with optional namespace filtering
pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) { pub async fn watch(
&self,
mut discovery_stream: DiscoveryStream,
target_namespace: Option<&str>,
) {
let global_namespace = target_namespace.is_none_or(is_global_namespace); let global_namespace = target_namespace.is_none_or(is_global_namespace);
while let Some(event) = events_rx.recv().await { while let Some(result) = discovery_stream.next().await {
let event = match result {
Ok(event) => event,
Err(err) => {
tracing::error!(%err, "Error in discovery stream");
continue;
}
};
match event { match event {
WatchEvent::Put(kv) => { DiscoveryEvent::Added(instance) => {
let key = kv.key_str(); // Extract EndpointId, instance_id, and card from the discovery instance
let endpoint_id = match key_extract(key) { let (endpoint_id, instance_id, mut card) = match &instance {
Ok((eid, _)) => eid, DiscoveryInstance::Model {
namespace,
component,
endpoint,
instance_id,
..
} => {
let eid = EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
};
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => (eid, *instance_id, card),
Err(err) => { Err(err) => {
tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance."); tracing::error!(%err, instance_id, "Failed to deserialize model card");
continue;
}
}
}
_ => {
tracing::error!(
"Unexpected discovery instance type (expected ModelCard)"
);
continue; continue;
} }
}; };
...@@ -127,21 +162,6 @@ impl ModelWatcher { ...@@ -127,21 +162,6 @@ impl ModelWatcher {
continue; continue;
} }
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
// If we already have a worker for this model, and the ModelDeploymentCard // If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance // cards don't match, alert, and don't add the new instance
let can_add = let can_add =
...@@ -164,7 +184,10 @@ impl ModelWatcher { ...@@ -164,7 +184,10 @@ impl ModelWatcher {
continue; continue;
} }
match self.handle_put(key, &endpoint_id, &mut card).await { // Use instance_id as the HashMap key (simpler and sufficient since keys are opaque)
let key = format!("{:x}", instance_id);
match self.handle_put(&key, &endpoint_id, &mut card).await {
Ok(()) => { Ok(()) => {
tracing::info!( tracing::info!(
model_name = card.name(), model_name = card.name(),
...@@ -183,10 +206,12 @@ impl ModelWatcher { ...@@ -183,10 +206,12 @@ impl ModelWatcher {
} }
} }
} }
WatchEvent::Delete(key) => { DiscoveryEvent::Removed(instance_id) => {
let deleted_key = key.as_ref(); // Use instance_id hex as the HashMap key (matches what we saved with)
let key = format!("{:x}", instance_id);
match self match self
.handle_delete(deleted_key, target_namespace, global_namespace) .handle_delete(&key, target_namespace, global_namespace)
.await .await
{ {
Ok(Some(model_name)) => { Ok(Some(model_name)) => {
...@@ -559,35 +584,39 @@ impl ModelWatcher { ...@@ -559,35 +584,39 @@ impl ModelWatcher {
/// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> { async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let store = self.drt.store(); let discovery = self.drt.discovery();
let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else { let instances = discovery.list(DiscoveryQuery::AllModels).await?;
// no cards
return Ok(vec![]);
};
let entries = card_bucket.entries().await?;
let mut results = Vec::with_capacity(entries.len()); let mut results = Vec::with_capacity(instances.len());
for (key, card_bytes) in entries { for instance in instances {
let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) { match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => { Ok(card) => {
let maybe_endpoint_id = // Extract EndpointId from the instance
key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id); let endpoint_id = match &instance {
let endpoint_id = match maybe_endpoint_id { dynamo_runtime::discovery::DiscoveryInstance::Model {
Ok(eid) => eid, namespace,
Err(err) => { component,
tracing::error!(%err, "Skipping invalid key, not string or not EndpointId"); endpoint,
..
} => EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
},
_ => {
tracing::error!(
"Unexpected discovery instance type (expected ModelCard)"
);
continue; continue;
} }
}; };
(endpoint_id, card) results.push((endpoint_id, card));
} }
Err(err) => { Err(err) => {
let value = String::from_utf8_lossy(&card_bytes); tracing::error!(%err, "Failed to deserialize model card");
tracing::error!(%err, %value, "Invalid JSON in model card");
continue; continue;
} }
}; }
results.push(r);
} }
Ok(results) Ok(results)
} }
...@@ -611,41 +640,3 @@ impl ModelWatcher { ...@@ -611,41 +640,3 @@ impl ModelWatcher {
Ok(all.into_iter().map(|(_eid, card)| card).collect()) Ok(all.into_iter().map(|(_eid, card)| card).collect())
} }
} }
/// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
if !s.starts_with(model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
let parts: Vec<&str> = s.split('/').collect();
// Need at least prefix model_card::ROOT_PATH (2 parts) + namespace, component, name (3 parts)
if parts.len() <= 5 {
anyhow::bail!("Invalid format: not enough path segments in {s}");
}
let endpoint_id = EndpointId {
namespace: parts[2].to_string(),
component: parts[3].to_string(),
name: parts[4].to_string(),
};
Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_key_extract() {
let input = format!(
"{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH
);
let (endpoint_id, _) = key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate");
}
}
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
use crate::kv_router::KV_METRICS_SUBJECT; use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent; use crate::kv_router::scoring::LoadEvent;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::component::Client; use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait}; use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
...@@ -79,28 +79,23 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -79,28 +79,23 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let endpoint = &self.client.endpoint; let endpoint = &self.client.endpoint;
let component = endpoint.component(); let component = endpoint.component();
let Some(etcd_client) = component.drt().etcd_client() else { let cancellation_token = component.drt().child_token();
// Static mode, no monitoring needed
return Ok(());
};
// Watch for runtime config updates from model deployment cards // Watch for runtime config updates from model deployment cards via discovery interface
let runtime_configs_watcher = watch_prefix_with_extraction( let discovery = component.drt().discovery();
etcd_client, let discovery_stream = discovery
model_card::ROOT_PATH, .list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
component.drt().child_token(),
)
.await?; .await?;
let mut config_events_rx = runtime_configs_watcher.receiver(); let mut config_events_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Subscribe to KV metrics events // Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?; let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
let worker_load_states = self.worker_load_states.clone(); let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone(); let client = self.client.clone();
let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold; let busy_threshold = self.busy_threshold;
// Spawn background monitoring task // Spawn background monitoring task
......
...@@ -10,7 +10,7 @@ use crate::{ ...@@ -10,7 +10,7 @@ use crate::{
entrypoint::{self, EngineConfig}, entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter, PrefillRouter}, kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration, migration::Migration,
model_card::{self, ModelDeploymentCard}, model_card::ModelDeploymentCard,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
...@@ -59,7 +59,6 @@ pub async fn prepare_engine( ...@@ -59,7 +59,6 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic(local_model) => { EngineConfig::Dynamic(local_model) => {
let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new( let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(), distributed_runtime.clone(),
...@@ -68,14 +67,16 @@ pub async fn prepare_engine( ...@@ -68,14 +67,16 @@ pub async fn prepare_engine(
None, None,
None, None,
)); ));
let (_, receiver) = store.watch( let discovery = distributed_runtime.discovery();
model_card::ROOT_PATH, let discovery_stream = discovery
None, .list_and_watch(
distributed_runtime.primary_token(), dynamo_runtime::discovery::DiscoveryQuery::AllModels,
); Some(distributed_runtime.primary_token().clone()),
)
.await?;
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await; inner_watch_obj.watch(discovery_stream, None).await;
}); });
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
......
...@@ -9,15 +9,14 @@ use crate::{ ...@@ -9,15 +9,14 @@ use crate::{
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{DistributedRuntime, storage::key_value_store::KeyValueStoreManager};
/// Build and run an KServe gRPC service /// Build and run an KServe gRPC service
pub async fn run( pub async fn run(
...@@ -30,7 +29,6 @@ pub async fn run( ...@@ -30,7 +29,6 @@ pub async fn run(
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
let store = Arc::new(distributed_runtime.store().clone());
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config(); let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
...@@ -43,7 +41,6 @@ pub async fn run( ...@@ -43,7 +41,6 @@ pub async fn run(
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
store,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -166,34 +163,39 @@ pub async fn run( ...@@ -166,34 +163,39 @@ pub async fn run(
/// Spawns a task that watches for new models in store, /// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
target_namespace: Option<String>, target_namespace: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let watch_obj = ModelWatcher::new( let watch_obj = ModelWatcher::new(
runtime, runtime.clone(),
model_manager, model_manager,
router_mode, router_mode,
kv_router_config, kv_router_config,
busy_threshold, busy_threshold,
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token); let discovery = runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
dynamo_runtime::discovery::DiscoveryQuery::AllModels,
Some(runtime.primary_token()),
)
.await?;
// [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service // [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service
// endpoint being exposed, gRPC doesn't have the same concept as the KServe service // endpoint being exposed, gRPC doesn't have the same concept as the KServe service
// only has one kind of inference endpoint. // only has one kind of inference endpoint.
// Pass the sender to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver, target_namespace.as_deref()).await; watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -10,7 +10,6 @@ use crate::{ ...@@ -10,7 +10,6 @@ use crate::{
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -19,7 +18,6 @@ use crate::{ ...@@ -19,7 +18,6 @@ use crate::{
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run( pub async fn run(
...@@ -69,7 +67,6 @@ pub async fn run( ...@@ -69,7 +67,6 @@ pub async fn run(
// This allows the /health endpoint to query store for active instances // This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let store = Arc::new(distributed_runtime.store().clone());
let router_config = engine_config.local_model().router_config(); let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
...@@ -84,7 +81,6 @@ pub async fn run( ...@@ -84,7 +81,6 @@ pub async fn run(
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
http_service.state().manager_clone(), http_service.state().manager_clone(),
store,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -271,7 +267,6 @@ pub async fn run( ...@@ -271,7 +267,6 @@ pub async fn run(
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
...@@ -279,16 +274,21 @@ async fn run_watcher( ...@@ -279,16 +274,21 @@ async fn run_watcher(
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime, runtime.clone(),
model_manager, model_manager,
router_mode, router_mode,
kv_router_config, kv_router_config,
busy_threshold, busy_threshold,
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token); let discovery = runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
dynamo_runtime::discovery::DiscoveryQuery::AllModels,
Some(runtime.primary_token()),
)
.await?;
// Create a channel to receive model type updates // Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32); let (tx, mut rx) = tokio::sync::mpsc::channel(32);
...@@ -302,9 +302,11 @@ async fn run_watcher( ...@@ -302,9 +302,11 @@ async fn run_watcher(
} }
}); });
// Pass the sender to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver, target_namespace.as_deref()).await; watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -6,7 +6,7 @@ use axum::{http::Method, response::IntoResponse, routing::post, Json, Router}; ...@@ -6,7 +6,7 @@ use axum::{http::Method, response::IntoResponse, routing::post, Json, Router};
use serde_json::json; use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt}; use dynamo_runtime::{discovery::DiscoveryQuery, pipeline::PushRouter, stream::StreamExt};
pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks"; pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks";
...@@ -150,7 +150,14 @@ async fn clear_kv_blocks_handler( ...@@ -150,7 +150,14 @@ async fn clear_kv_blocks_handler(
} }
}; };
let instances = match component_obj.list_instances().await { let discovery_client = distributed.discovery();
let discovery_key = DiscoveryQuery::Endpoint {
namespace: namespace.clone(),
component: component.clone(),
endpoint: CLEAR_KV_ENDPOINT.to_string(),
};
let discovery_instances = match discovery_client.list(discovery_key).await {
Ok(instances) => instances, Ok(instances) => instances,
Err(e) => { Err(e) => {
add_worker_result( add_worker_result(
...@@ -165,11 +172,11 @@ async fn clear_kv_blocks_handler( ...@@ -165,11 +172,11 @@ async fn clear_kv_blocks_handler(
} }
}; };
if instances.is_empty() { if discovery_instances.is_empty() {
add_worker_result( add_worker_result(
false, false,
entry_name, entry_name,
"No instances found for worker group", "No instances found for clear_kv_blocks endpoint",
namespace, namespace,
component, component,
None, None,
...@@ -177,30 +184,13 @@ async fn clear_kv_blocks_handler( ...@@ -177,30 +184,13 @@ async fn clear_kv_blocks_handler(
continue; continue;
} }
let instances_filtered = instances let instances_filtered: Vec<dynamo_runtime::component::Instance> = discovery_instances
.clone()
.into_iter() .into_iter()
.filter(|instance| instance.endpoint == CLEAR_KV_ENDPOINT) .filter_map(|di| match di {
.collect::<Vec<_>>(); dynamo_runtime::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
_ => None,
if instances_filtered.is_empty() { })
let found_endpoints: Vec<String> = instances
.iter()
.map(|instance| instance.endpoint.clone())
.collect(); .collect();
add_worker_result(
false,
entry_name,
&format!(
"Worker group doesn't support clear_kv_blocks. Supported endpoints: {}",
found_endpoints.join(", ")
),
namespace,
component,
None,
);
continue;
}
for instance in &instances_filtered { for instance in &instances_filtered {
let instance_name = format!("{}-instance-{}", entry.name, instance.id()); let instance_name = format!("{}-instance-{}", entry.name, instance.id());
......
...@@ -52,14 +52,13 @@ async fn live_handler( ...@@ -52,14 +52,13 @@ async fn live_handler(
async fn health_handler( async fn health_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>, axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let instances = match list_all_instances(state.store()).await { let instances = match list_all_instances(state.discovery()).await {
Ok(instances) => instances, Ok(instances) => instances,
Err(err) => { Err(err) => {
tracing::warn!(%err, "Failed to fetch instances from store"); tracing::warn!(%err, "Failed to fetch instances from discovery");
vec![] vec![]
} }
}; };
let mut endpoints: Vec<String> = instances let mut endpoints: Vec<String> = instances
.iter() .iter()
.map(|instance| instance.endpoint_id().as_url()) .map(|instance| instance.endpoint_id().as_url())
......
...@@ -18,6 +18,7 @@ use crate::request_template::RequestTemplate; ...@@ -18,6 +18,7 @@ use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig; use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::discovery::{Discovery, KVStoreDiscovery};
use dynamo_runtime::logging::make_request_span; use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::metrics::prometheus_names::name_prefix; use dynamo_runtime::metrics::prometheus_names::name_prefix;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager; use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
...@@ -31,6 +32,7 @@ pub struct State { ...@@ -31,6 +32,7 @@ pub struct State {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
store: KeyValueStoreManager, store: KeyValueStoreManager,
discovery_client: Arc<dyn Discovery>,
flags: StateFlags, flags: StateFlags,
} }
...@@ -72,10 +74,18 @@ impl StateFlags { ...@@ -72,10 +74,18 @@ impl StateFlags {
impl State { impl State {
pub fn new(manager: Arc<ModelManager>, store: KeyValueStoreManager) -> Self { pub fn new(manager: Arc<ModelManager>, store: KeyValueStoreManager) -> Self {
// Initialize discovery backed by KV store
// Create a cancellation token for the discovery's watch streams
let discovery_client = {
let cancel_token = CancellationToken::new();
Arc::new(KVStoreDiscovery::new(store.clone(), cancel_token)) as Arc<dyn Discovery>
};
Self { Self {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
store, store,
discovery_client,
flags: StateFlags { flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false), chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false),
...@@ -102,6 +112,10 @@ impl State { ...@@ -102,6 +112,10 @@ impl State {
&self.store &self.store
} }
pub fn discovery(&self) -> Arc<dyn Discovery> {
self.discovery_client.clone()
}
// TODO // TODO
pub fn sse_keep_alive(&self) -> Option<Duration> { pub fn sse_keep_alive(&self) -> Option<Duration> {
None None
......
...@@ -9,13 +9,13 @@ use anyhow::Result; ...@@ -9,13 +9,13 @@ use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, InstanceSource}, component::{Component, InstanceSource},
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{ pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait, SingleIn, async_trait,
}, },
prelude::*,
protocols::annotated::Annotated, protocols::annotated::Annotated,
utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction}, traits::DistributedRuntimeProvider,
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -47,7 +47,7 @@ use crate::{ ...@@ -47,7 +47,7 @@ use crate::{
subscriber::start_kv_router_background, subscriber::start_kv_router_background,
}, },
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard}, model_card::ModelDeploymentCard,
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
}; };
...@@ -233,22 +233,20 @@ impl KvRouter { ...@@ -233,22 +233,20 @@ impl KvRouter {
} }
}; };
// Create runtime config watcher using the generic etcd watcher // Watch for runtime config updates via discovery interface
// TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality let discovery = component.drt().discovery();
let etcd_client = component let discovery_key = DiscoveryQuery::EndpointModels {
.drt() namespace: component.namespace().name().to_string(),
.etcd_client() component: component.name().to_string(),
.expect("Cannot KV route without etcd client"); endpoint: "generate".to_string(),
};
let runtime_configs_watcher = watch_prefix_with_extraction( let discovery_stream = discovery
etcd_client, .list_and_watch(discovery_key, Some(cancellation_token.clone()))
&format!("{}/{}", model_card::ROOT_PATH, component.path()),
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(),
)
.await?; .await?;
let runtime_configs_rx = runtime_configs_watcher.receiver(); let runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
let indexer = if kv_router_config.overlap_score_weight == 0.0 { let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes // When overlap_score_weight is zero, we don't need to track prefixes
......
...@@ -8,6 +8,7 @@ use std::{collections::HashSet, time::Duration}; ...@@ -8,6 +8,7 @@ use std::{collections::HashSet, time::Duration};
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
discovery::DiscoveryQuery,
prelude::*, prelude::*,
traits::events::EventPublisher, traits::events::EventPublisher,
transports::{ transports::{
...@@ -15,6 +16,7 @@ use dynamo_runtime::{ ...@@ -15,6 +16,7 @@ use dynamo_runtime::{
nats::{NatsQueue, Slug}, nats::{NatsQueue, Slug},
}, },
}; };
use futures::StreamExt;
use rand::Rng; use rand::Rng;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -281,10 +283,15 @@ pub async fn start_kv_router_background( ...@@ -281,10 +283,15 @@ pub async fn start_kv_router_background(
// Get the generate endpoint and watch for instance deletions // Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate"); let generate_endpoint = component.endpoint("generate");
let (_instance_prefix, mut instance_event_rx) = etcd_client let discovery_client = component.drt().discovery();
.kv_get_and_watch_prefix(generate_endpoint.etcd_root()) let discovery_key = DiscoveryQuery::Endpoint {
.await? namespace: component.namespace().name().to_string(),
.dissolve(); component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let mut instance_event_stream = discovery_client
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.await?;
// Get instances_rx for tracking current workers // Get instances_rx for tracking current workers
let client = generate_endpoint.client().await?; let client = generate_endpoint.client().await?;
...@@ -337,25 +344,20 @@ pub async fn start_kv_router_background( ...@@ -337,25 +344,20 @@ pub async fn start_kv_router_background(
} }
// Handle generate endpoint instance deletion events // Handle generate endpoint instance deletion events
Some(event) = instance_event_rx.recv() => { Some(discovery_event_result) = instance_event_stream.next() => {
let WatchEvent::Delete(kv) = event else { let Ok(discovery_event) = discovery_event_result else {
continue; continue;
}; };
let key = String::from_utf8_lossy(kv.key()); let dynamo_runtime::discovery::DiscoveryEvent::Removed(worker_id) = discovery_event else {
let Some(worker_id_str) = key.split(&['/', ':'][..]).next_back() else {
tracing::warn!("Could not extract worker ID from instance key: {key}");
continue; continue;
}; };
// Parse as hexadecimal (base 16) tracing::warn!(
let Ok(worker_id) = u64::from_str_radix(worker_id_str, 16) else { worker_id = worker_id,
tracing::warn!("Could not parse worker ID from instance key: {key}"); "DISCOVERY: Generate endpoint instance removed, removing worker"
continue; );
};
tracing::info!("Generate endpoint instance deleted, removing worker {worker_id}");
if let Err(e) = remove_worker_tx.send(worker_id).await { if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}"); tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
} }
......
...@@ -5,14 +5,14 @@ use std::fs; ...@@ -5,14 +5,14 @@ use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use dynamo_runtime::component::Endpoint; use dynamo_runtime::component::Endpoint;
use dynamo_runtime::discovery::DiscoverySpec;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug; use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{MediaDecoder, MediaFetcher}; use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
...@@ -434,13 +434,16 @@ impl LocalModel { ...@@ -434,13 +434,16 @@ impl LocalModel {
self.card.model_type = model_type; self.card.model_type = model_type;
self.card.model_input = model_input; self.card.model_input = model_input;
// Publish the Model Deployment Card to KV store // Register the Model Deployment Card via discovery interface
let card_store = endpoint.drt().store(); let discovery = endpoint.drt().discovery();
let key = Key::from_raw(endpoint.unique_path(card_store.connection_id())); let spec = DiscoverySpec::from_model(
endpoint.component().namespace().name().to_string(),
endpoint.component().name().to_string(),
endpoint.name().to_string(),
&self.card,
)?;
let _instance = discovery.register(spec).await?;
let _outcome = card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
Ok(()) Ok(())
} }
} }
......
...@@ -295,10 +295,12 @@ mod integration_tests { ...@@ -295,10 +295,12 @@ mod integration_tests {
use super::*; use super::*;
use dynamo_llm::{ use dynamo_llm::{
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig, discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder, model_card, local_model::LocalModelBuilder,
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use std::sync::Arc; use std::sync::Arc;
#[tokio::test] #[tokio::test]
...@@ -333,7 +335,7 @@ mod integration_tests { ...@@ -333,7 +335,7 @@ mod integration_tests {
.build() .build()
.unwrap(); .unwrap();
// Set up model watcher to discover models from etcd (like production) // Set up model watcher to discover models via discovery interface (like production)
// This is crucial for the polling task to find model entries // This is crucial for the polling task to find model entries
let model_watcher = ModelWatcher::new( let model_watcher = ModelWatcher::new(
...@@ -343,17 +345,19 @@ mod integration_tests { ...@@ -343,17 +345,19 @@ mod integration_tests {
None, None,
None, None,
); );
// Start watching etcd for model registrations // Start watching for model registrations via discovery interface
let store = Arc::new(distributed_runtime.store().clone()); let discovery = distributed_runtime.discovery();
let (_, receiver) = store.watch( let discovery_stream = discovery
model_card::ROOT_PATH, .list_and_watch(
None, DiscoveryQuery::AllModels,
distributed_runtime.primary_token(), Some(distributed_runtime.primary_token()),
); )
.await
.unwrap();
// Spawn watcher task to discover models from etcd // Spawn watcher task to discover models
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
model_watcher.watch(receiver, None).await; model_watcher.watch(discovery_stream, None).await;
}); });
// Set up the engine following the StaticFull pattern from http.rs // Set up the engine following the StaticFull pattern from http.rs
......
...@@ -75,7 +75,7 @@ pub use client::{Client, InstanceSource}; ...@@ -75,7 +75,7 @@ pub use client::{Client, InstanceSource};
/// An instance is namespace+component+endpoint+lease_id and must be unique. /// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "v1/instances"; pub const INSTANCE_ROOT_PATH: &str = "v1/instances";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TransportType { pub enum TransportType {
NatsTcp(String), NatsTcp(String),
...@@ -278,21 +278,24 @@ impl Component { ...@@ -278,21 +278,24 @@ impl Component {
} }
pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> { pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let client = self.drt.store(); let discovery = self.drt.discovery();
let Some(bucket) = client.get_bucket(&self.instance_root()).await? else {
return Ok(vec![]); let discovery_query = crate::discovery::DiscoveryQuery::ComponentEndpoints {
}; namespace: self.namespace.name(),
let entries = bucket.entries().await?; component: self.name.clone(),
let mut instances = Vec::with_capacity(entries.len());
for (name, bytes) in entries.into_iter() {
let val = match serde_json::from_slice::<Instance>(&bytes) {
Ok(val) => val,
Err(err) => {
anyhow::bail!("Error converting storage response to Instance: {err}. {name}",);
}
}; };
instances.push(val);
} let discovery_instances = discovery.list(discovery_query).await?;
// Extract Instance from DiscoveryInstance::Endpoint wrapper
let mut instances: Vec<Instance> = discovery_instances
.into_iter()
.filter_map(|di| match di {
crate::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
_ => None, // Ignore all other variants (ModelCard, etc.)
})
.collect();
instances.sort(); instances.sort();
Ok(instances) Ok(instances)
} }
......
...@@ -9,6 +9,7 @@ use crate::{ ...@@ -9,6 +9,7 @@ use crate::{
storage::key_value_store::{KeyValueStoreManager, WatchEvent}, storage::key_value_store::{KeyValueStoreManager, WatchEvent},
}; };
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use futures::StreamExt;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::unix::pipe::Receiver; use tokio::net::unix::pipe::Receiver;
...@@ -67,18 +68,33 @@ impl Client { ...@@ -67,18 +68,33 @@ impl Client {
// Client with auto-discover instances using etcd // Client with auto-discover instances using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
tracing::debug!(
"Client::new_dynamic: Creating dynamic client for endpoint: {}",
endpoint.path()
);
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1); const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
// create live endpoint watcher
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?; let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
tracing::debug!(
"Client::new_dynamic: Got instance source for endpoint: {}",
endpoint.path()
);
let client = Client { let client = Client {
endpoint, endpoint: endpoint.clone(),
instance_source: instance_source.clone(), instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}; };
tracing::debug!(
"Client::new_dynamic: Starting instance source monitor for endpoint: {}",
endpoint.path()
);
client.monitor_instance_source(); client.monitor_instance_source();
tracing::debug!(
"Client::new_dynamic: Successfully created dynamic client for endpoint: {}",
endpoint.path()
);
Ok(client) Ok(client)
} }
...@@ -113,17 +129,47 @@ impl Client { ...@@ -113,17 +129,47 @@ impl Client {
/// Wait for at least one Instance to be available for this Endpoint /// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> { pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
tracing::debug!(
"wait_for_instances: Starting wait for endpoint: {}",
self.endpoint.path()
);
let mut instances: Vec<Instance> = vec![]; let mut instances: Vec<Instance> = vec![];
if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() { if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() {
// wait for there to be 1 or more endpoints // wait for there to be 1 or more endpoints
let mut iteration = 0;
loop { loop {
instances = rx.borrow_and_update().to_vec(); instances = rx.borrow_and_update().to_vec();
tracing::debug!(
"wait_for_instances: iteration={}, current_instance_count={}, endpoint={}",
iteration,
instances.len(),
self.endpoint.path()
);
if instances.is_empty() { if instances.is_empty() {
tracing::debug!(
"wait_for_instances: No instances yet, waiting for change notification for endpoint: {}",
self.endpoint.path()
);
rx.changed().await?; rx.changed().await?;
tracing::debug!(
"wait_for_instances: Change notification received for endpoint: {}",
self.endpoint.path()
);
} else { } else {
tracing::info!(
"wait_for_instances: Found {} instance(s) for endpoint: {}",
instances.len(),
self.endpoint.path()
);
break; break;
} }
iteration += 1;
} }
} else {
tracing::debug!(
"wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}",
self.endpoint.path()
);
} }
Ok(instances) Ok(instances)
} }
...@@ -159,14 +205,22 @@ impl Client { ...@@ -159,14 +205,22 @@ impl Client {
fn monitor_instance_source(&self) { fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token(); let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone(); let client = self.clone();
let endpoint_path = self.endpoint.path();
tracing::debug!(
"monitor_instance_source: Starting monitor for endpoint: {}",
endpoint_path
);
tokio::task::spawn(async move { tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() { let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => { InstanceSource::Static => {
tracing::error!("Static instance source is not watchable"); tracing::error!(
"monitor_instance_source: Static instance source is not watchable"
);
return; return;
} }
InstanceSource::Dynamic(rx) => rx.clone(), InstanceSource::Dynamic(rx) => rx.clone(),
}; };
let mut iteration = 0;
while !cancel_token.is_cancelled() { while !cancel_token.is_cancelled() {
let instance_ids: Vec<u64> = rx let instance_ids: Vec<u64> = rx
.borrow_and_update() .borrow_and_update()
...@@ -174,17 +228,37 @@ impl Client { ...@@ -174,17 +228,37 @@ impl Client {
.map(|instance| instance.id()) .map(|instance| instance.id())
.collect(); .collect();
tracing::debug!(
"monitor_instance_source: iteration={}, instance_count={}, instance_ids={:?}, endpoint={}",
iteration,
instance_ids.len(),
instance_ids,
endpoint_path
);
// TODO: this resets both tracked available and free instances // TODO: this resets both tracked available and free instances
client.instance_avail.store(Arc::new(instance_ids.clone())); client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids)); client.instance_free.store(Arc::new(instance_ids.clone()));
tracing::debug!("instance source updated"); tracing::debug!(
"monitor_instance_source: instance source updated, endpoint={}",
endpoint_path
);
if let Err(err) = rx.changed().await { if let Err(err) = rx.changed().await {
tracing::error!("The Sender is dropped: {}", err); tracing::error!(
"monitor_instance_source: The Sender is dropped: {}, endpoint={}",
err,
endpoint_path
);
cancel_token.cancel(); cancel_token.cancel();
} }
iteration += 1;
} }
tracing::debug!(
"monitor_instance_source: Monitor loop exiting for endpoint: {}",
endpoint_path
);
}); });
} }
...@@ -195,100 +269,141 @@ impl Client { ...@@ -195,100 +269,141 @@ impl Client {
let instance_sources = drt.instance_sources(); let instance_sources = drt.instance_sources();
let mut instance_sources = instance_sources.lock().await; let mut instance_sources = instance_sources.lock().await;
tracing::debug!(
"get_or_create_dynamic_instance_source: Checking cache for endpoint: {}",
endpoint.path()
);
if let Some(instance_source) = instance_sources.get(endpoint) { if let Some(instance_source) = instance_sources.get(endpoint) {
if let Some(instance_source) = instance_source.upgrade() { if let Some(instance_source) = instance_source.upgrade() {
tracing::debug!(
"get_or_create_dynamic_instance_source: Found cached instance source for endpoint: {}",
endpoint.path()
);
return Ok(instance_source); return Ok(instance_source);
} else { } else {
tracing::debug!(
"get_or_create_dynamic_instance_source: Cached instance source was dropped, removing for endpoint: {}",
endpoint.path()
);
instance_sources.remove(endpoint); instance_sources.remove(endpoint);
} }
} }
let prefix = endpoint.etcd_root(); tracing::debug!(
let store = Arc::new(drt.store().clone()); "get_or_create_dynamic_instance_source: Creating new instance source for endpoint: {}",
let (_, mut kv_event_rx) = endpoint.path()
store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token()); );
let discovery = drt.discovery();
let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
namespace: endpoint.component.namespace.name.clone(),
component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(),
};
tracing::debug!(
"get_or_create_dynamic_instance_source: Calling discovery.list_and_watch for query: {:?}",
discovery_query
);
let mut discovery_stream = discovery
.list_and_watch(discovery_query.clone(), None)
.await?;
tracing::debug!(
"get_or_create_dynamic_instance_source: Got discovery stream for query: {:?}",
discovery_query
);
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]); let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let secondary = endpoint.component.drt.runtime.secondary().clone(); let secondary = endpoint.component.drt.runtime.secondary().clone();
// this task should be included in the registry
// currently this is created once per client, but this object/task should only be instantiated
// once per worker/instance
secondary.spawn(async move { secondary.spawn(async move {
tracing::debug!("Starting endpoint watcher for prefix: {prefix}"); tracing::debug!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
let mut map = HashMap::new(); let mut map: HashMap<u64, Instance> = HashMap::new();
let mut event_count = 0;
loop { loop {
let kv_event = tokio::select! { let discovery_event = tokio::select! {
_ = watch_tx.closed() => { _ = watch_tx.closed() => {
tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}"); tracing::debug!("endpoint_watcher: all watchers have closed; shutting down for discovery query: {:?}", discovery_query);
break;
}
discovery_event = discovery_stream.next() => {
tracing::debug!("endpoint_watcher: Received stream event for discovery query: {:?}", discovery_query);
match discovery_event {
Some(Ok(event)) => {
tracing::debug!("endpoint_watcher: Got Ok event: {:?}", event);
event
},
Some(Err(e)) => {
tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
break; break;
} }
kv_event = kv_event_rx.recv() => {
match kv_event {
Some(kv_event) => kv_event,
None => { None => {
tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}"); tracing::debug!("endpoint_watcher: watch stream has closed; shutting down for discovery query: {:?}", discovery_query);
break; break;
} }
} }
} }
}; };
match kv_event { event_count += 1;
WatchEvent::Put(kv) => { tracing::debug!("endpoint_watcher: Processing event #{} for discovery query: {:?}", event_count, discovery_query);
let key = kv.key_str();
if !key.starts_with(&prefix) {
continue;
}
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
}
match serde_json::from_slice::<Instance>(kv.value()) { match discovery_event {
Ok(val) => map.insert(key.to_string(), val), crate::discovery::DiscoveryEvent::Added(discovery_instance) => {
Err(err) => { match discovery_instance {
tracing::error!(error = %err, prefix, crate::discovery::DiscoveryInstance::Endpoint(instance) => {
"Unable to parse put endpoint event; shutting down endpoint watcher"); tracing::debug!(
break; "endpoint_watcher: Added endpoint instance_id={}, namespace={}, component={}, endpoint={}",
instance.instance_id,
instance.namespace,
instance.component,
instance.endpoint
);
map.insert(instance.instance_id, instance);
} }
}; _ => {
tracing::debug!("endpoint_watcher: Ignoring non-endpoint instance (Model, etc.) for discovery query: {:?}", discovery_query);
} }
WatchEvent::Delete(key) => {
let key = key.as_ref();
if !key.starts_with(&prefix) {
continue;
} }
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
} }
map.remove(key); crate::discovery::DiscoveryEvent::Removed(instance_id) => {
tracing::debug!(
"endpoint_watcher: Removed instance_id={} for discovery query: {:?}",
instance_id,
discovery_query
);
map.remove(&instance_id);
} }
} }
let instances: Vec<Instance> = map.values().cloned().collect(); let instances: Vec<Instance> = map.values().cloned().collect();
tracing::debug!(
"endpoint_watcher: Current map size={}, sending update for discovery query: {:?}",
instances.len(),
discovery_query
);
if watch_tx.send(instances).is_err() { if watch_tx.send(instances).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix); tracing::debug!("endpoint_watcher: Unable to send watch updates; shutting down for discovery query: {:?}", discovery_query);
break; break;
} }
} }
tracing::debug!("Completed endpoint watcher for prefix: {prefix}"); tracing::debug!("endpoint_watcher: Completed for discovery query: {:?}, total events processed: {}", discovery_query, event_count);
let _ = watch_tx.send(vec![]); let _ = watch_tx.send(vec![]);
}); });
let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx)); let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source)); instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
tracing::debug!(
"get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}",
endpoint.path()
);
Ok(instance_source) Ok(instance_source)
} }
} }
...@@ -193,26 +193,23 @@ impl EndpointConfigBuilder { ...@@ -193,26 +193,23 @@ impl EndpointConfigBuilder {
result result
}); });
let info = Instance { // Register this endpoint instance in the discovery plane
// The discovery interface abstracts storage backend (etcd, k8s, etc) and provides
// consistent registration/discovery across the system.
let discovery = endpoint.drt().discovery();
let discovery_spec = crate::discovery::DiscoverySpec::Endpoint {
namespace: namespace_name.clone(),
component: component_name.clone(), component: component_name.clone(),
endpoint: endpoint_name.clone(), endpoint: endpoint_name.clone(),
namespace: namespace_name.clone(), transport: TransportType::NatsTcp(subject.clone()),
instance_id: connection_id,
transport: TransportType::NatsTcp(subject),
}; };
let info = serde_json::to_vec_pretty(&info)?; if let Err(e) = discovery.register(discovery_spec).await {
let store = endpoint.drt().store();
let instances_bucket = store
.get_or_create_bucket(super::INSTANCE_ROOT_PATH, None)
.await?;
let key = key_value_store::Key::from_raw(endpoint.unique_path(connection_id));
if let Err(err) = instances_bucket.insert(&key, info.into(), 0).await {
tracing::error!( tracing::error!(
component_name, component_name,
endpoint_name, endpoint_name,
error = %err, error = %e,
"Unable to register service for discovery" "Unable to register service for discovery"
); );
endpoint_shutdown_token.cancel(); endpoint_shutdown_token.cancel();
...@@ -220,6 +217,7 @@ impl EndpointConfigBuilder { ...@@ -220,6 +217,7 @@ impl EndpointConfigBuilder {
"Unable to register service for discovery. Check discovery service status" "Unable to register service for discovery. Check discovery service status"
)); ));
} }
task.await??; task.await??;
Ok(()) Ok(())
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
use crate::{CancellationToken, Result};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
const INSTANCES_BUCKET: &str = "v1/instances";
const MODELS_BUCKET: &str = "v1/mdc";
/// Discovery implementation backed by a KeyValueStore
pub struct KVStoreDiscovery {
store: Arc<KeyValueStoreManager>,
cancel_token: CancellationToken,
}
impl KVStoreDiscovery {
pub fn new(store: KeyValueStoreManager, cancel_token: CancellationToken) -> Self {
Self {
store: Arc::new(store),
cancel_token,
}
}
/// Build the key path for an endpoint (relative to bucket, not absolute)
fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
}
/// Build the key path for a model (relative to bucket, not absolute)
fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
}
/// Extract prefix for querying based on discovery query
fn query_prefix(query: &DiscoveryQuery) -> String {
match query {
DiscoveryQuery::AllEndpoints => INSTANCES_BUCKET.to_string(),
DiscoveryQuery::NamespacedEndpoints { namespace } => {
format!("{}/{}", INSTANCES_BUCKET, namespace)
}
DiscoveryQuery::ComponentEndpoints {
namespace,
component,
} => {
format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component)
}
DiscoveryQuery::Endpoint {
namespace,
component,
endpoint,
} => {
format!(
"{}/{}/{}/{}",
INSTANCES_BUCKET, namespace, component, endpoint
)
}
DiscoveryQuery::AllModels => MODELS_BUCKET.to_string(),
DiscoveryQuery::NamespacedModels { namespace } => {
format!("{}/{}", MODELS_BUCKET, namespace)
}
DiscoveryQuery::ComponentModels {
namespace,
component,
} => {
format!("{}/{}/{}", MODELS_BUCKET, namespace, component)
}
DiscoveryQuery::EndpointModels {
namespace,
component,
endpoint,
} => {
format!("{}/{}/{}/{}", MODELS_BUCKET, namespace, component, endpoint)
}
}
}
/// Strip bucket prefix from a key if present, returning the relative path within the bucket
/// For example: "v1/instances/ns/comp/ep" -> "ns/comp/ep"
/// Or if already relative: "ns/comp/ep" -> "ns/comp/ep"
fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
// Try to strip "bucket_name/" from the beginning
if let Some(stripped) = key.strip_prefix(bucket_name) {
// Strip the leading slash if present
stripped.strip_prefix('/').unwrap_or(stripped)
} else {
// Key is already relative to bucket
key
}
}
/// Check if a key matches the given prefix, handling both absolute and relative key formats
/// This works regardless of whether keys include the bucket prefix (etcd) or not (memory)
fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
// Normalize both the key and prefix to relative paths (without bucket prefix)
let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
// Empty prefix matches everything in the bucket
if relative_prefix.is_empty() {
return true;
}
// Check if the relative key starts with the relative prefix
relative_key.starts_with(relative_prefix)
}
/// Parse and deserialize a discovery instance from KV store entry
fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
let instance: DiscoveryInstance = serde_json::from_slice(value)?;
Ok(instance)
}
}
#[async_trait]
impl Discovery for KVStoreDiscovery {
fn instance_id(&self) -> u64 {
self.store.connection_id()
}
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance_id = self.instance_id();
let instance = spec.with_instance_id(instance_id);
let (bucket_name, key_path) = match &instance {
DiscoveryInstance::Endpoint(inst) => {
let key = Self::endpoint_key(
&inst.namespace,
&inst.component,
&inst.endpoint,
inst.instance_id,
);
tracing::debug!(
"KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
inst.instance_id,
inst.namespace,
inst.component,
inst.endpoint,
key
);
(INSTANCES_BUCKET, key)
}
DiscoveryInstance::Model {
namespace,
component,
endpoint,
instance_id,
..
} => {
let key = Self::model_key(namespace, component, endpoint, *instance_id);
tracing::debug!(
"KVStoreDiscovery::register: Registering model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
instance_id,
namespace,
component,
endpoint,
key
);
(MODELS_BUCKET, key)
}
};
// Serialize the instance
let instance_json = serde_json::to_vec(&instance)?;
tracing::debug!(
"KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
instance_json.len(),
key_path
);
// Store in the KV store with no TTL (instances persist until explicitly removed)
tracing::debug!(
"KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
bucket_name,
key_path
);
let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
let key = crate::storage::key_value_store::Key::from_raw(key_path.clone());
tracing::debug!(
"KVStoreDiscovery::register: Inserting into bucket={}, key={}",
bucket_name,
key_path
);
// Use revision 0 for initial registration
let outcome = bucket.insert(&key, instance_json.into(), 0).await?;
tracing::debug!(
"KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
instance_id,
key_path,
outcome
);
Ok(instance)
}
async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
let prefix = Self::query_prefix(&query);
let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
INSTANCES_BUCKET
} else {
MODELS_BUCKET
};
// Get bucket - if it doesn't exist, return empty list
let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
return Ok(Vec::new());
};
// Get all entries from the bucket
let entries = bucket.entries().await?;
// Filter by prefix and deserialize
let mut instances = Vec::new();
for (key_str, value) in entries {
if Self::matches_prefix(&key_str, &prefix, bucket_name) {
match Self::parse_instance(&value) {
Ok(instance) => instances.push(instance),
Err(e) => {
tracing::warn!(key = %key_str, error = %e, "Failed to parse discovery instance");
}
}
}
}
Ok(instances)
}
async fn list_and_watch(
&self,
query: DiscoveryQuery,
cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream> {
let prefix = Self::query_prefix(&query);
let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
INSTANCES_BUCKET
} else {
MODELS_BUCKET
};
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
query,
prefix,
bucket_name
);
// Use the provided cancellation token, or fall back to the default token
let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
// Use the KeyValueStoreManager's watch mechanism
let (_, mut rx) = self.store.clone().watch(
bucket_name,
None, // No TTL
cancel_token,
);
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Got watch receiver for bucket={}",
bucket_name
);
// Create a stream that filters and transforms WatchEvents to DiscoveryEvents
let stream = async_stream::stream! {
let mut event_count = 0;
tracing::debug!("KVStoreDiscovery::list_and_watch: Stream started, waiting for events on prefix={}", prefix);
while let Some(event) = rx.recv().await {
event_count += 1;
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Received event #{} for prefix={}",
event_count,
prefix
);
let discovery_event = match event {
WatchEvent::Put(kv) => {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Put event, key={}, prefix={}, matches={}",
kv.key_str(),
prefix,
Self::matches_prefix(kv.key_str(), &prefix, bucket_name)
);
// Check if this key matches our prefix
if !Self::matches_prefix(kv.key_str(), &prefix, bucket_name) {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Skipping key {} (doesn't match prefix {})",
kv.key_str(),
prefix
);
continue;
}
match Self::parse_instance(kv.value()) {
Ok(instance) => {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Added event for instance_id={}, key={}",
instance.instance_id(),
kv.key_str()
);
Some(DiscoveryEvent::Added(instance))
},
Err(e) => {
tracing::warn!(
key = %kv.key_str(),
error = %e,
"Failed to parse discovery instance from watch event"
);
None
}
}
}
WatchEvent::Delete(kv) => {
let key_str = kv.as_ref();
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Delete event, key={}, prefix={}",
key_str,
prefix
);
// Check if this key matches our prefix
if !Self::matches_prefix(key_str, &prefix, bucket_name) {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Skipping deleted key {} (doesn't match prefix {})",
key_str,
prefix
);
continue;
}
// Extract instance_id from the key path, not the value
// Delete events have empty values in etcd, so we parse the instance_id from the key
// Key format: "v1/instances/namespace/component/endpoint/{instance_id:x}"
let key_parts: Vec<&str> = key_str.split('/').collect();
match key_parts.last() {
Some(instance_id_hex) => {
match u64::from_str_radix(instance_id_hex, 16) {
Ok(instance_id) => {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={}, key={}",
instance_id,
key_str
);
Some(DiscoveryEvent::Removed(instance_id))
}
Err(e) => {
tracing::warn!(
key = %key_str,
error = %e,
"Failed to parse instance_id hex from deleted key"
);
None
}
}
}
None => {
tracing::warn!(
key = %key_str,
"Delete event key has no path components"
);
None
}
}
}
};
if let Some(event) = discovery_event {
tracing::debug!("KVStoreDiscovery::list_and_watch: Yielding event: {:?}", event);
yield Ok(event);
} else {
tracing::debug!("KVStoreDiscovery::list_and_watch: Event was filtered out (None)");
}
}
tracing::debug!("KVStoreDiscovery::list_and_watch: Stream ended after {} events for prefix={}", event_count, prefix);
};
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Returning stream for query={:?}",
query
);
Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::component::TransportType;
#[tokio::test]
async fn test_kv_store_discovery_register_endpoint() {
let store = KeyValueStoreManager::memory();
let cancel_token = CancellationToken::new();
let client = KVStoreDiscovery::new(store, cancel_token);
let spec = DiscoverySpec::Endpoint {
namespace: "test".to_string(),
component: "comp1".to_string(),
endpoint: "ep1".to_string(),
transport: TransportType::NatsTcp("nats://localhost:4222".to_string()),
};
let instance = client.register(spec).await.unwrap();
match instance {
DiscoveryInstance::Endpoint(inst) => {
assert_eq!(inst.namespace, "test");
assert_eq!(inst.component, "comp1");
assert_eq!(inst.endpoint, "ep1");
}
_ => panic!("Expected Endpoint instance"),
}
}
#[tokio::test]
async fn test_kv_store_discovery_list() {
let store = KeyValueStoreManager::memory();
let cancel_token = CancellationToken::new();
let client = KVStoreDiscovery::new(store, cancel_token);
// Register multiple endpoints
let spec1 = DiscoverySpec::Endpoint {
namespace: "ns1".to_string(),
component: "comp1".to_string(),
endpoint: "ep1".to_string(),
transport: TransportType::NatsTcp("nats://localhost:4222".to_string()),
};
client.register(spec1).await.unwrap();
let spec2 = DiscoverySpec::Endpoint {
namespace: "ns1".to_string(),
component: "comp1".to_string(),
endpoint: "ep2".to_string(),
transport: TransportType::NatsTcp("nats://localhost:4222".to_string()),
};
client.register(spec2).await.unwrap();
let spec3 = DiscoverySpec::Endpoint {
namespace: "ns2".to_string(),
component: "comp2".to_string(),
endpoint: "ep1".to_string(),
transport: TransportType::NatsTcp("nats://localhost:4222".to_string()),
};
client.register(spec3).await.unwrap();
// List all endpoints
let all = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
assert_eq!(all.len(), 3);
// List namespaced endpoints
let ns1 = client
.list(DiscoveryQuery::NamespacedEndpoints {
namespace: "ns1".to_string(),
})
.await
.unwrap();
assert_eq!(ns1.len(), 2);
// List component endpoints
let comp1 = client
.list(DiscoveryQuery::ComponentEndpoints {
namespace: "ns1".to_string(),
component: "comp1".to_string(),
})
.await
.unwrap();
assert_eq!(comp1.len(), 2);
}
#[tokio::test]
async fn test_kv_store_discovery_watch() {
let store = KeyValueStoreManager::memory();
let cancel_token = CancellationToken::new();
let client = Arc::new(KVStoreDiscovery::new(store, cancel_token.clone()));
// Start watching before registering
let mut stream = client
.list_and_watch(DiscoveryQuery::AllEndpoints, None)
.await
.unwrap();
let client_clone = client.clone();
let register_task = tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let spec = DiscoverySpec::Endpoint {
namespace: "test".to_string(),
component: "comp1".to_string(),
endpoint: "ep1".to_string(),
transport: TransportType::NatsTcp("nats://localhost:4222".to_string()),
};
client_clone.register(spec).await.unwrap();
});
// Wait for the added event
let event = stream.next().await.unwrap().unwrap();
match event {
DiscoveryEvent::Added(instance) => match instance {
DiscoveryInstance::Endpoint(inst) => {
assert_eq!(inst.namespace, "test");
assert_eq!(inst.component, "comp1");
assert_eq!(inst.endpoint, "ep1");
}
_ => panic!("Expected Endpoint instance"),
},
_ => panic!("Expected Added event"),
}
register_task.await.unwrap();
cancel_token.cancel();
}
}
...@@ -2,10 +2,9 @@ ...@@ -2,10 +2,9 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{ use super::{
DiscoveryClient, DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoverySpec, Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
DiscoveryStream,
}; };
use crate::Result; use crate::{CancellationToken, Result};
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
...@@ -21,14 +20,14 @@ impl SharedMockRegistry { ...@@ -21,14 +20,14 @@ impl SharedMockRegistry {
} }
} }
/// Mock implementation of DiscoveryClient for testing /// Mock implementation of Discovery for testing
/// We can potentially remove this once we have KeyValueDiscoveryClient implemented /// We can potentially remove this once we have KVStoreDiscovery fully tested
pub struct MockDiscoveryClient { pub struct MockDiscovery {
instance_id: u64, instance_id: u64,
registry: SharedMockRegistry, registry: SharedMockRegistry,
} }
impl MockDiscoveryClient { impl MockDiscovery {
pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self { pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
let instance_id = instance_id.unwrap_or_else(|| { let instance_id = instance_id.unwrap_or_else(|| {
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
...@@ -43,24 +42,24 @@ impl MockDiscoveryClient { ...@@ -43,24 +42,24 @@ impl MockDiscoveryClient {
} }
} }
/// Helper function to check if an instance matches a discovery key query /// Helper function to check if an instance matches a discovery query
fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool { fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
match (instance, key) { match (instance, query) {
// Endpoint matching // Endpoint matching
(DiscoveryInstance::Endpoint(_), DiscoveryKey::AllEndpoints) => true, (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
(DiscoveryInstance::Endpoint(inst), DiscoveryKey::NamespacedEndpoints { namespace }) => { (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
&inst.namespace == namespace &inst.namespace == namespace
} }
( (
DiscoveryInstance::Endpoint(inst), DiscoveryInstance::Endpoint(inst),
DiscoveryKey::ComponentEndpoints { DiscoveryQuery::ComponentEndpoints {
namespace, namespace,
component, component,
}, },
) => &inst.namespace == namespace && &inst.component == component, ) => &inst.namespace == namespace && &inst.component == component,
( (
DiscoveryInstance::Endpoint(inst), DiscoveryInstance::Endpoint(inst),
DiscoveryKey::Endpoint { DiscoveryQuery::Endpoint {
namespace, namespace,
component, component,
endpoint, endpoint,
...@@ -71,33 +70,33 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool { ...@@ -71,33 +70,33 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
&& &inst.endpoint == endpoint && &inst.endpoint == endpoint
} }
// ModelCard matching // Model matching
(DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllModelCards) => true, (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
( (
DiscoveryInstance::ModelCard { DiscoveryInstance::Model {
namespace: inst_ns, .. namespace: inst_ns, ..
}, },
DiscoveryKey::NamespacedModelCards { namespace }, DiscoveryQuery::NamespacedModels { namespace },
) => inst_ns == namespace, ) => inst_ns == namespace,
( (
DiscoveryInstance::ModelCard { DiscoveryInstance::Model {
namespace: inst_ns, namespace: inst_ns,
component: inst_comp, component: inst_comp,
.. ..
}, },
DiscoveryKey::ComponentModelCards { DiscoveryQuery::ComponentModels {
namespace, namespace,
component, component,
}, },
) => inst_ns == namespace && inst_comp == component, ) => inst_ns == namespace && inst_comp == component,
( (
DiscoveryInstance::ModelCard { DiscoveryInstance::Model {
namespace: inst_ns, namespace: inst_ns,
component: inst_comp, component: inst_comp,
endpoint: inst_ep, endpoint: inst_ep,
.. ..
}, },
DiscoveryKey::EndpointModelCards { DiscoveryQuery::EndpointModels {
namespace, namespace,
component, component,
endpoint, endpoint,
...@@ -107,23 +106,23 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool { ...@@ -107,23 +106,23 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
// Cross-type matches return false // Cross-type matches return false
( (
DiscoveryInstance::Endpoint(_), DiscoveryInstance::Endpoint(_),
DiscoveryKey::AllModelCards DiscoveryQuery::AllModels
| DiscoveryKey::NamespacedModelCards { .. } | DiscoveryQuery::NamespacedModels { .. }
| DiscoveryKey::ComponentModelCards { .. } | DiscoveryQuery::ComponentModels { .. }
| DiscoveryKey::EndpointModelCards { .. }, | DiscoveryQuery::EndpointModels { .. },
) => false, ) => false,
( (
DiscoveryInstance::ModelCard { .. }, DiscoveryInstance::Model { .. },
DiscoveryKey::AllEndpoints DiscoveryQuery::AllEndpoints
| DiscoveryKey::NamespacedEndpoints { .. } | DiscoveryQuery::NamespacedEndpoints { .. }
| DiscoveryKey::ComponentEndpoints { .. } | DiscoveryQuery::ComponentEndpoints { .. }
| DiscoveryKey::Endpoint { .. }, | DiscoveryQuery::Endpoint { .. },
) => false, ) => false,
} }
} }
#[async_trait] #[async_trait]
impl DiscoveryClient for MockDiscoveryClient { impl Discovery for MockDiscovery {
fn instance_id(&self) -> u64 { fn instance_id(&self) -> u64 {
self.instance_id self.instance_id
} }
...@@ -140,16 +139,20 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -140,16 +139,20 @@ impl DiscoveryClient for MockDiscoveryClient {
Ok(instance) Ok(instance)
} }
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>> { async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
let instances = self.registry.instances.lock().unwrap(); let instances = self.registry.instances.lock().unwrap();
Ok(instances Ok(instances
.iter() .iter()
.filter(|instance| matches_key(instance, &key)) .filter(|instance| matches_query(instance, &query))
.cloned() .cloned()
.collect()) .collect())
} }
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> { async fn list_and_watch(
&self,
query: DiscoveryQuery,
_cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream> {
use std::collections::HashSet; use std::collections::HashSet;
let registry = self.registry.clone(); let registry = self.registry.clone();
...@@ -162,7 +165,7 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -162,7 +165,7 @@ impl DiscoveryClient for MockDiscoveryClient {
let instances = registry.instances.lock().unwrap(); let instances = registry.instances.lock().unwrap();
instances instances
.iter() .iter()
.filter(|instance| matches_key(instance, &key)) .filter(|instance| matches_query(instance, &query))
.cloned() .cloned()
.collect() .collect()
}; };
...@@ -170,7 +173,7 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -170,7 +173,7 @@ impl DiscoveryClient for MockDiscoveryClient {
let current_ids: HashSet<_> = current.iter().map(|i| { let current_ids: HashSet<_> = current.iter().map(|i| {
match i { match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id, DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id, DiscoveryInstance::Model { instance_id, .. } => *instance_id,
} }
}).collect(); }).collect();
...@@ -178,7 +181,7 @@ impl DiscoveryClient for MockDiscoveryClient { ...@@ -178,7 +181,7 @@ impl DiscoveryClient for MockDiscoveryClient {
for instance in current { for instance in current {
let id = match &instance { let id = match &instance {
DiscoveryInstance::Endpoint(inst) => inst.instance_id, DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id, DiscoveryInstance::Model { instance_id, .. } => *instance_id,
}; };
if known_instances.insert(id) { if known_instances.insert(id) {
yield Ok(DiscoveryEvent::Added(instance)); yield Ok(DiscoveryEvent::Added(instance));
...@@ -207,8 +210,8 @@ mod tests { ...@@ -207,8 +210,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_mock_discovery_add_and_remove() { async fn test_mock_discovery_add_and_remove() {
let registry = SharedMockRegistry::new(); let registry = SharedMockRegistry::new();
let client1 = MockDiscoveryClient::new(Some(1), registry.clone()); let client1 = MockDiscovery::new(Some(1), registry.clone());
let client2 = MockDiscoveryClient::new(Some(2), registry.clone()); let client2 = MockDiscovery::new(Some(2), registry.clone());
let spec = DiscoverySpec::Endpoint { let spec = DiscoverySpec::Endpoint {
namespace: "test-ns".to_string(), namespace: "test-ns".to_string(),
...@@ -217,14 +220,14 @@ mod tests { ...@@ -217,14 +220,14 @@ mod tests {
transport: crate::component::TransportType::NatsTcp("test-subject".to_string()), transport: crate::component::TransportType::NatsTcp("test-subject".to_string()),
}; };
let key = DiscoveryKey::Endpoint { let query = DiscoveryQuery::Endpoint {
namespace: "test-ns".to_string(), namespace: "test-ns".to_string(),
component: "test-comp".to_string(), component: "test-comp".to_string(),
endpoint: "test-ep".to_string(), endpoint: "test-ep".to_string(),
}; };
// Start watching // Start watching
let mut stream = client1.list_and_watch(key.clone()).await.unwrap(); let mut stream = client1.list_and_watch(query.clone(), None).await.unwrap();
// Add first instance // Add first instance
client1.register(spec.clone()).await.unwrap(); client1.register(spec.clone()).await.unwrap();
...@@ -251,7 +254,7 @@ mod tests { ...@@ -251,7 +254,7 @@ mod tests {
// Remove first instance // Remove first instance
registry.instances.lock().unwrap().retain(|i| match i { registry.instances.lock().unwrap().retain(|i| match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1, DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id != 1, DiscoveryInstance::Model { instance_id, .. } => *instance_id != 1,
}); });
let event = stream.next().await.unwrap().unwrap(); let event = stream.next().await.unwrap().unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::CancellationToken;
use crate::Result; use crate::Result;
use crate::component::TransportType; use crate::component::TransportType;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -9,7 +10,10 @@ use serde::{Deserialize, Serialize}; ...@@ -9,7 +10,10 @@ use serde::{Deserialize, Serialize};
use std::pin::Pin; use std::pin::Pin;
mod mock; mod mock;
pub use mock::{MockDiscoveryClient, SharedMockRegistry}; pub use mock::{MockDiscovery, SharedMockRegistry};
mod kv_store;
pub use kv_store::KVStoreDiscovery;
pub mod utils; pub mod utils;
pub use utils::watch_and_extract_field; pub use utils::watch_and_extract_field;
...@@ -17,7 +21,7 @@ pub use utils::watch_and_extract_field; ...@@ -17,7 +21,7 @@ pub use utils::watch_and_extract_field;
/// Query key for prefix-based discovery queries /// Query key for prefix-based discovery queries
/// Supports hierarchical queries from all endpoints down to specific endpoints /// Supports hierarchical queries from all endpoints down to specific endpoints
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscoveryKey { pub enum DiscoveryQuery {
/// Query all endpoints in the system /// Query all endpoints in the system
AllEndpoints, AllEndpoints,
/// Query all endpoints in a specific namespace /// Query all endpoints in a specific namespace
...@@ -35,15 +39,15 @@ pub enum DiscoveryKey { ...@@ -35,15 +39,15 @@ pub enum DiscoveryKey {
component: String, component: String,
endpoint: String, endpoint: String,
}, },
AllModelCards, AllModels,
NamespacedModelCards { NamespacedModels {
namespace: String, namespace: String,
}, },
ComponentModelCards { ComponentModels {
namespace: String, namespace: String,
component: String, component: String,
}, },
EndpointModelCards { EndpointModels {
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
...@@ -62,21 +66,21 @@ pub enum DiscoverySpec { ...@@ -62,21 +66,21 @@ pub enum DiscoverySpec {
/// Transport type and routing information /// Transport type and routing information
transport: TransportType, transport: TransportType,
}, },
ModelCard { Model {
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
/// ModelDeploymentCard serialized as JSON /// ModelDeploymentCard serialized as JSON
/// This allows lib/runtime to remain independent of lib/llm types /// This allows lib/runtime to remain independent of lib/llm types
/// DiscoverySpec.from_model_card() and DiscoveryInstance.deserialize_model_card() are ergonomic helpers to create and deserialize the model card. /// DiscoverySpec.from_model() and DiscoveryInstance.deserialize_model() are ergonomic helpers to create and deserialize the model card.
card_json: serde_json::Value, card_json: serde_json::Value,
}, },
} }
impl DiscoverySpec { impl DiscoverySpec {
/// Creates a ModelCard discovery spec from a serializable type /// Creates a Model discovery spec from a serializable type
/// The card will be serialized to JSON to avoid cross-crate dependencies /// The card will be serialized to JSON to avoid cross-crate dependencies
pub fn from_model_card<T>( pub fn from_model<T>(
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
...@@ -86,7 +90,7 @@ impl DiscoverySpec { ...@@ -86,7 +90,7 @@ impl DiscoverySpec {
T: Serialize, T: Serialize,
{ {
let card_json = serde_json::to_value(card)?; let card_json = serde_json::to_value(card)?;
Ok(Self::ModelCard { Ok(Self::Model {
namespace, namespace,
component, component,
endpoint, endpoint,
...@@ -109,12 +113,12 @@ impl DiscoverySpec { ...@@ -109,12 +113,12 @@ impl DiscoverySpec {
instance_id, instance_id,
transport, transport,
}), }),
Self::ModelCard { Self::Model {
namespace, namespace,
component, component,
endpoint, endpoint,
card_json, card_json,
} => DiscoveryInstance::ModelCard { } => DiscoveryInstance::Model {
namespace, namespace,
component, component,
endpoint, endpoint,
...@@ -132,7 +136,7 @@ impl DiscoverySpec { ...@@ -132,7 +136,7 @@ impl DiscoverySpec {
pub enum DiscoveryInstance { pub enum DiscoveryInstance {
/// Registered endpoint instance - wraps the component::Instance directly /// Registered endpoint instance - wraps the component::Instance directly
Endpoint(crate::component::Instance), Endpoint(crate::component::Instance),
ModelCard { Model {
namespace: String, namespace: String,
component: String, component: String,
endpoint: String, endpoint: String,
...@@ -148,26 +152,26 @@ impl DiscoveryInstance { ...@@ -148,26 +152,26 @@ impl DiscoveryInstance {
pub fn instance_id(&self) -> u64 { pub fn instance_id(&self) -> u64 {
match self { match self {
Self::Endpoint(inst) => inst.instance_id, Self::Endpoint(inst) => inst.instance_id,
Self::ModelCard { instance_id, .. } => *instance_id, Self::Model { instance_id, .. } => *instance_id,
} }
} }
/// Deserializes the model card JSON into the specified type T /// Deserializes the model JSON into the specified type T
/// Returns an error if this is not a ModelCard instance or if deserialization fails /// Returns an error if this is not a Model instance or if deserialization fails
pub fn deserialize_model_card<T>(&self) -> crate::Result<T> pub fn deserialize_model<T>(&self) -> crate::Result<T>
where where
T: for<'de> Deserialize<'de>, T: for<'de> Deserialize<'de>,
{ {
match self { match self {
Self::ModelCard { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?), Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
Self::Endpoint(_) => { Self::Endpoint(_) => {
crate::raise!("Cannot deserialize model card from Endpoint instance") crate::raise!("Cannot deserialize model from Endpoint instance")
} }
} }
} }
} }
/// Events emitted by the discovery client watch stream /// Events emitted by the discovery watch stream
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiscoveryEvent { pub enum DiscoveryEvent {
/// A new instance was added /// A new instance was added
...@@ -179,9 +183,9 @@ pub enum DiscoveryEvent { ...@@ -179,9 +183,9 @@ pub enum DiscoveryEvent {
/// Stream type for discovery events /// Stream type for discovery events
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>; pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
/// Discovery client trait for service discovery across different backends /// Discovery trait for service discovery across different backends
#[async_trait] #[async_trait]
pub trait DiscoveryClient: Send + Sync { pub trait Discovery: Send + Sync {
/// Returns a unique identifier for this worker (e.g lease id if using etcd or generated id for memory store) /// Returns a unique identifier for this worker (e.g lease id if using etcd or generated id for memory store)
/// Discovery objects created by this worker will be associated with this id. /// Discovery objects created by this worker will be associated with this id.
fn instance_id(&self) -> u64; fn instance_id(&self) -> u64;
...@@ -189,10 +193,15 @@ pub trait DiscoveryClient: Send + Sync { ...@@ -189,10 +193,15 @@ pub trait DiscoveryClient: Send + Sync {
/// Registers an object in the discovery plane with the instance id /// Registers an object in the discovery plane with the instance id
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>; async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
/// Returns a list of currently registered instances for the given discovery key /// Returns a list of currently registered instances for the given discovery query
/// This is a one-time snapshot without watching for changes /// This is a one-time snapshot without watching for changes
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>>; async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>>;
/// Returns a stream of discovery events (Added/Removed) for the given discovery key /// Returns a stream of discovery events (Added/Removed) for the given discovery query
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>; /// The optional cancellation token can be used to stop the watch stream
async fn list_and_watch(
&self,
query: DiscoveryQuery,
cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream>;
} }
...@@ -26,7 +26,7 @@ use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream}; ...@@ -26,7 +26,7 @@ use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream};
/// ///
/// # Example /// # Example
/// ```ignore /// ```ignore
/// let stream = discovery.list_and_watch(DiscoveryKey::ComponentModelCards { ... }).await?; /// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
/// let runtime_configs_rx = watch_and_extract_field( /// let runtime_configs_rx = watch_and_extract_field(
/// stream, /// stream,
/// |card: ModelDeploymentCard| card.runtime_config, /// |card: ModelDeploymentCard| card.runtime_config,
...@@ -62,7 +62,7 @@ where ...@@ -62,7 +62,7 @@ where
let instance_id = instance.instance_id(); let instance_id = instance.instance_id();
// Deserialize the full instance into type T // Deserialize the full instance into type T
let deserialized: T = match instance.deserialize_model_card() { let deserialized: T = match instance.deserialize_model() {
Ok(d) => d, Ok(d) => d,
Err(e) => { Err(e) => {
tracing::warn!( tracing::warn!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment