"examples/vscode:/vscode.git/clone" did not exist on "b204456630e8c83ed9d981191cc583e49923f38d"
Unverified commit 09b26bf6, authored by mohammedabdulwahhab, committed by GitHub
Browse files

fix: refactor to use service discovery (#4092)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
parent 04f7579b
......@@ -72,17 +72,14 @@ async def test_hello_world(example_dir, server_process):
# Run the client for 5 seconds
client_output = await run_client(example_dir)
# Split output into lines and filter out empty lines
# Split output into lines and strip whitespace, filter out empty lines
lines = [line.strip() for line in client_output.split("\n") if line.strip()]
# Each client iteration produces 4 lines in about 4 seconds
# The client ran for 5 seconds so the first iteration is expected to be completed
# Assert the first 4 lines are the expected sequence
assert (
len(lines) >= 4
), f"Expected at least 4 lines, got {len(lines)}. Output: {lines}"
# Check that all 4 expected lines appear in the output
expected_lines = ["Hello world!", "Hello sun!", "Hello moon!", "Hello star!"]
for i, expected_line in enumerate(expected_lines):
assert (
lines[i] == expected_line
), f"Line {i+1}: expected '{expected_line}', got '{lines[i]}'. Full output: {lines}"
for expected_line in expected_lines:
assert expected_line in lines, (
f"Expected line '{expected_line}' not found in output.\n" f"Lines: {lines}"
)
......@@ -2,26 +2,27 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::sync::Notify;
use tokio::sync::mpsc::Sender;
use anyhow::Context as _;
use tokio::sync::{Notify, mpsc::Receiver};
use futures::StreamExt;
use dynamo_runtime::{
DistributedRuntime,
discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoveryStream},
pipeline::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter,
},
protocols::{EndpointId, annotated::Annotated},
storage::key_value_store::WatchEvent,
};
use crate::{
backend::Backend,
entrypoint,
kv_router::{KvRouterConfig, PrefillRouter},
model_card::{self, ModelDeploymentCard},
model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{
......@@ -99,17 +100,51 @@ impl ModelWatcher {
}
/// Common watch logic with optional namespace filtering
pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
pub async fn watch(
&self,
mut discovery_stream: DiscoveryStream,
target_namespace: Option<&str>,
) {
let global_namespace = target_namespace.is_none_or(is_global_namespace);
while let Some(event) = events_rx.recv().await {
while let Some(result) = discovery_stream.next().await {
let event = match result {
Ok(event) => event,
Err(err) => {
tracing::error!(%err, "Error in discovery stream");
continue;
}
};
match event {
WatchEvent::Put(kv) => {
let key = kv.key_str();
let endpoint_id = match key_extract(key) {
Ok((eid, _)) => eid,
DiscoveryEvent::Added(instance) => {
// Extract EndpointId, instance_id, and card from the discovery instance
let (endpoint_id, instance_id, mut card) = match &instance {
DiscoveryInstance::Model {
namespace,
component,
endpoint,
instance_id,
..
} => {
let eid = EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
};
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => (eid, *instance_id, card),
Err(err) => {
tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance.");
tracing::error!(%err, instance_id, "Failed to deserialize model card");
continue;
}
}
}
_ => {
tracing::error!(
"Unexpected discovery instance type (expected ModelCard)"
);
continue;
}
};
......@@ -127,21 +162,6 @@ impl ModelWatcher {
continue;
}
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
// If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance
let can_add =
......@@ -164,7 +184,10 @@ impl ModelWatcher {
continue;
}
match self.handle_put(key, &endpoint_id, &mut card).await {
// Use instance_id as the HashMap key (simpler and sufficient since keys are opaque)
let key = format!("{:x}", instance_id);
match self.handle_put(&key, &endpoint_id, &mut card).await {
Ok(()) => {
tracing::info!(
model_name = card.name(),
......@@ -183,10 +206,12 @@ impl ModelWatcher {
}
}
}
WatchEvent::Delete(key) => {
let deleted_key = key.as_ref();
DiscoveryEvent::Removed(instance_id) => {
// Use instance_id hex as the HashMap key (matches what we saved with)
let key = format!("{:x}", instance_id);
match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.handle_delete(&key, target_namespace, global_namespace)
.await
{
Ok(Some(model_name)) => {
......@@ -559,35 +584,39 @@ impl ModelWatcher {
/// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let store = self.drt.store();
let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else {
// no cards
return Ok(vec![]);
};
let entries = card_bucket.entries().await?;
let discovery = self.drt.discovery();
let instances = discovery.list(DiscoveryQuery::AllModels).await?;
let mut results = Vec::with_capacity(entries.len());
for (key, card_bytes) in entries {
let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) {
let mut results = Vec::with_capacity(instances.len());
for instance in instances {
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => {
let maybe_endpoint_id =
key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id);
let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid,
Err(err) => {
tracing::error!(%err, "Skipping invalid key, not string or not EndpointId");
// Extract EndpointId from the instance
let endpoint_id = match &instance {
dynamo_runtime::discovery::DiscoveryInstance::Model {
namespace,
component,
endpoint,
..
} => EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
},
_ => {
tracing::error!(
"Unexpected discovery instance type (expected ModelCard)"
);
continue;
}
};
(endpoint_id, card)
results.push((endpoint_id, card));
}
Err(err) => {
let value = String::from_utf8_lossy(&card_bytes);
tracing::error!(%err, %value, "Invalid JSON in model card");
tracing::error!(%err, "Failed to deserialize model card");
continue;
}
};
results.push(r);
}
}
Ok(results)
}
......@@ -611,41 +640,3 @@ impl ModelWatcher {
Ok(all.into_iter().map(|(_eid, card)| card).collect())
}
}
/// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
if !s.starts_with(model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
let parts: Vec<&str> = s.split('/').collect();
// Need at least prefix model_card::ROOT_PATH (2 parts) + namespace, component, name (3 parts)
if parts.len() <= 5 {
anyhow::bail!("Invalid format: not enough path segments in {s}");
}
let endpoint_id = EndpointId {
namespace: parts[2].to_string(),
component: parts[3].to_string(),
name: parts[4].to_string(),
};
Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed key yields both the endpoint parts and the trailing instance id.
    #[test]
    fn test_key_extract() {
        let input = format!(
            "{}/dynamo/backend/generate/694d9981145a61ad",
            model_card::ROOT_PATH
        );
        let (endpoint_id, instance_id) = key_extract(&input).unwrap();
        assert_eq!(endpoint_id.namespace, "dynamo");
        assert_eq!(endpoint_id.component, "backend");
        assert_eq!(endpoint_id.name, "generate");
        assert_eq!(instance_id, "694d9981145a61ad");
    }

    /// A key without the trailing instance id segment is rejected.
    #[test]
    fn test_key_extract_too_few_segments() {
        let input = format!("{}/dynamo/backend/generate", model_card::ROOT_PATH);
        assert!(key_extract(&input).is_err());
    }

    /// A key that does not live under the model card root is rejected.
    #[test]
    fn test_key_extract_wrong_prefix() {
        assert!(key_extract("wrong/prefix/dynamo/backend/generate/694d9981145a61ad").is_err());
    }
}
......@@ -3,12 +3,12 @@
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;
......@@ -79,28 +79,23 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let endpoint = &self.client.endpoint;
let component = endpoint.component();
let Some(etcd_client) = component.drt().etcd_client() else {
// Static mode, no monitoring needed
return Ok(());
};
let cancellation_token = component.drt().child_token();
// Watch for runtime config updates from model deployment cards
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
model_card::ROOT_PATH,
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
component.drt().child_token(),
)
// Watch for runtime config updates from model deployment cards via discovery interface
let discovery = component.drt().discovery();
let discovery_stream = discovery
.list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
.await?;
let mut config_events_rx = runtime_configs_watcher.receiver();
let mut config_events_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold;
// Spawn background monitoring task
......
......@@ -10,7 +10,7 @@ use crate::{
entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::{self, ModelDeploymentCard},
model_card::ModelDeploymentCard,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
......@@ -59,7 +59,6 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic(local_model) => {
let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(),
......@@ -68,14 +67,16 @@ pub async fn prepare_engine(
None,
None,
));
let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
let discovery = distributed_runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
dynamo_runtime::discovery::DiscoveryQuery::AllModels,
Some(distributed_runtime.primary_token().clone()),
)
.await?;
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await;
inner_watch_obj.watch(discovery_stream, None).await;
});
tracing::info!("Waiting for remote model..");
......
......@@ -9,15 +9,14 @@ use crate::{
entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{DistributedRuntime, storage::key_value_store::KeyValueStoreManager};
/// Build and run an KServe gRPC service
pub async fn run(
......@@ -30,7 +29,6 @@ pub async fn run(
let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => {
let store = Arc::new(distributed_runtime.store().clone());
let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves, add them to gRPC service
......@@ -43,7 +41,6 @@ pub async fn run(
run_watcher(
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
store,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -166,34 +163,39 @@ pub async fn run(
/// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let watch_obj = ModelWatcher::new(
runtime,
runtime.clone(),
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
let discovery = runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
dynamo_runtime::discovery::DiscoveryQuery::AllModels,
Some(runtime.primary_token()),
)
.await?;
// [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service
// endpoint being exposed, gRPC doesn't have the same concept as the KServe service
// only has one kind of inference endpoint.
// Pass the sender to the watcher
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver, target_namespace.as_deref()).await;
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
});
Ok(())
......
......@@ -10,7 +10,6 @@ use crate::{
entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -19,7 +18,6 @@ use crate::{
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
/// Build and run an HTTP service
pub async fn run(
......@@ -69,7 +67,6 @@ pub async fn run(
// This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?;
let store = Arc::new(distributed_runtime.store().clone());
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves, add them to HTTP service
......@@ -84,7 +81,6 @@ pub async fn run(
run_watcher(
distributed_runtime.clone(),
http_service.state().manager_clone(),
store,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -271,7 +267,6 @@ pub async fn run(
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
......@@ -279,16 +274,21 @@ async fn run_watcher(
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let mut watch_obj = ModelWatcher::new(
runtime,
runtime.clone(),
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
let discovery = runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
dynamo_runtime::discovery::DiscoveryQuery::AllModels,
Some(runtime.primary_token()),
)
.await?;
// Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
......@@ -302,9 +302,11 @@ async fn run_watcher(
}
});
// Pass the sender to the watcher
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver, target_namespace.as_deref()).await;
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
});
Ok(())
......
......@@ -6,7 +6,7 @@ use axum::{http::Method, response::IntoResponse, routing::post, Json, Router};
use serde_json::json;
use std::sync::Arc;
use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt};
use dynamo_runtime::{discovery::DiscoveryQuery, pipeline::PushRouter, stream::StreamExt};
pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks";
......@@ -150,7 +150,14 @@ async fn clear_kv_blocks_handler(
}
};
let instances = match component_obj.list_instances().await {
let discovery_client = distributed.discovery();
let discovery_key = DiscoveryQuery::Endpoint {
namespace: namespace.clone(),
component: component.clone(),
endpoint: CLEAR_KV_ENDPOINT.to_string(),
};
let discovery_instances = match discovery_client.list(discovery_key).await {
Ok(instances) => instances,
Err(e) => {
add_worker_result(
......@@ -165,11 +172,11 @@ async fn clear_kv_blocks_handler(
}
};
if instances.is_empty() {
if discovery_instances.is_empty() {
add_worker_result(
false,
entry_name,
"No instances found for worker group",
"No instances found for clear_kv_blocks endpoint",
namespace,
component,
None,
......@@ -177,30 +184,13 @@ async fn clear_kv_blocks_handler(
continue;
}
let instances_filtered = instances
.clone()
let instances_filtered: Vec<dynamo_runtime::component::Instance> = discovery_instances
.into_iter()
.filter(|instance| instance.endpoint == CLEAR_KV_ENDPOINT)
.collect::<Vec<_>>();
if instances_filtered.is_empty() {
let found_endpoints: Vec<String> = instances
.iter()
.map(|instance| instance.endpoint.clone())
.filter_map(|di| match di {
dynamo_runtime::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
_ => None,
})
.collect();
add_worker_result(
false,
entry_name,
&format!(
"Worker group doesn't support clear_kv_blocks. Supported endpoints: {}",
found_endpoints.join(", ")
),
namespace,
component,
None,
);
continue;
}
for instance in &instances_filtered {
let instance_name = format!("{}-instance-{}", entry.name, instance.id());
......
......@@ -52,14 +52,13 @@ async fn live_handler(
async fn health_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
let instances = match list_all_instances(state.store()).await {
let instances = match list_all_instances(state.discovery()).await {
Ok(instances) => instances,
Err(err) => {
tracing::warn!(%err, "Failed to fetch instances from store");
tracing::warn!(%err, "Failed to fetch instances from discovery");
vec![]
}
};
let mut endpoints: Vec<String> = instances
.iter()
.map(|instance| instance.endpoint_id().as_url())
......
......@@ -18,6 +18,7 @@ use crate::request_template::RequestTemplate;
use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder;
use dynamo_runtime::discovery::{Discovery, KVStoreDiscovery};
use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::metrics::prometheus_names::name_prefix;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
......@@ -31,6 +32,7 @@ pub struct State {
metrics: Arc<Metrics>,
manager: Arc<ModelManager>,
store: KeyValueStoreManager,
discovery_client: Arc<dyn Discovery>,
flags: StateFlags,
}
......@@ -72,10 +74,18 @@ impl StateFlags {
impl State {
pub fn new(manager: Arc<ModelManager>, store: KeyValueStoreManager) -> Self {
// Initialize discovery backed by KV store
// Create a cancellation token for the discovery's watch streams
let discovery_client = {
let cancel_token = CancellationToken::new();
Arc::new(KVStoreDiscovery::new(store.clone(), cancel_token)) as Arc<dyn Discovery>
};
Self {
manager,
metrics: Arc::new(Metrics::default()),
store,
discovery_client,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
......@@ -102,6 +112,10 @@ impl State {
&self.store
}
/// Returns a new shared handle to the discovery client held by this state.
pub fn discovery(&self) -> Arc<dyn Discovery> {
    Arc::clone(&self.discovery_client)
}
// TODO
pub fn sse_keep_alive(&self) -> Option<Duration> {
None
......
......@@ -9,13 +9,13 @@ use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Component, InstanceSource},
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
},
prelude::*,
protocols::annotated::Annotated,
utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
traits::DistributedRuntimeProvider,
};
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
......@@ -47,7 +47,7 @@ use crate::{
subscriber::start_kv_router_background,
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
model_card::ModelDeploymentCard,
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
};
......@@ -233,22 +233,20 @@ impl KvRouter {
}
};
// Create runtime config watcher using the generic etcd watcher
// TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
let etcd_client = component
.drt()
.etcd_client()
.expect("Cannot KV route without etcd client");
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
&format!("{}/{}", model_card::ROOT_PATH, component.path()),
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(),
)
// Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();
let runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
......
......@@ -8,6 +8,7 @@ use std::{collections::HashSet, time::Duration};
use anyhow::Result;
use dynamo_runtime::{
component::Component,
discovery::DiscoveryQuery,
prelude::*,
traits::events::EventPublisher,
transports::{
......@@ -15,6 +16,7 @@ use dynamo_runtime::{
nats::{NatsQueue, Slug},
},
};
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
......@@ -281,10 +283,15 @@ pub async fn start_kv_router_background(
// Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate");
let (_instance_prefix, mut instance_event_rx) = etcd_client
.kv_get_and_watch_prefix(generate_endpoint.etcd_root())
.await?
.dissolve();
let discovery_client = component.drt().discovery();
let discovery_key = DiscoveryQuery::Endpoint {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let mut instance_event_stream = discovery_client
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.await?;
// Get instances_rx for tracking current workers
let client = generate_endpoint.client().await?;
......@@ -337,25 +344,20 @@ pub async fn start_kv_router_background(
}
// Handle generate endpoint instance deletion events
Some(event) = instance_event_rx.recv() => {
let WatchEvent::Delete(kv) = event else {
Some(discovery_event_result) = instance_event_stream.next() => {
let Ok(discovery_event) = discovery_event_result else {
continue;
};
let key = String::from_utf8_lossy(kv.key());
let Some(worker_id_str) = key.split(&['/', ':'][..]).next_back() else {
tracing::warn!("Could not extract worker ID from instance key: {key}");
let dynamo_runtime::discovery::DiscoveryEvent::Removed(worker_id) = discovery_event else {
continue;
};
// Parse as hexadecimal (base 16)
let Ok(worker_id) = u64::from_str_radix(worker_id_str, 16) else {
tracing::warn!("Could not parse worker ID from instance key: {key}");
continue;
};
tracing::warn!(
worker_id = worker_id,
"DISCOVERY: Generate endpoint instance removed, removing worker"
);
tracing::info!("Generate endpoint instance deleted, removing worker {worker_id}");
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
......
......@@ -5,14 +5,14 @@ use std::fs;
use std::path::{Path, PathBuf};
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::discovery::DiscoverySpec;
use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::request_template::RequestTemplate;
......@@ -434,13 +434,16 @@ impl LocalModel {
self.card.model_type = model_type;
self.card.model_input = model_input;
// Publish the Model Deployment Card to KV store
let card_store = endpoint.drt().store();
let key = Key::from_raw(endpoint.unique_path(card_store.connection_id()));
// Register the Model Deployment Card via discovery interface
let discovery = endpoint.drt().discovery();
let spec = DiscoverySpec::from_model(
endpoint.component().namespace().name().to_string(),
endpoint.component().name().to_string(),
endpoint.name().to_string(),
&self.card,
)?;
let _instance = discovery.register(spec).await?;
let _outcome = card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
Ok(())
}
}
......
......@@ -295,10 +295,12 @@ mod integration_tests {
use super::*;
use dynamo_llm::{
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder, model_card,
local_model::LocalModelBuilder,
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use std::sync::Arc;
#[tokio::test]
......@@ -333,7 +335,7 @@ mod integration_tests {
.build()
.unwrap();
// Set up model watcher to discover models from etcd (like production)
// Set up model watcher to discover models via discovery interface (like production)
// This is crucial for the polling task to find model entries
let model_watcher = ModelWatcher::new(
......@@ -343,17 +345,19 @@ mod integration_tests {
None,
None,
);
// Start watching etcd for model registrations
let store = Arc::new(distributed_runtime.store().clone());
let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
// Start watching for model registrations via discovery interface
let discovery = distributed_runtime.discovery();
let discovery_stream = discovery
.list_and_watch(
DiscoveryQuery::AllModels,
Some(distributed_runtime.primary_token()),
)
.await
.unwrap();
// Spawn watcher task to discover models from etcd
// Spawn watcher task to discover models
let _watcher_task = tokio::spawn(async move {
model_watcher.watch(receiver, None).await;
model_watcher.watch(discovery_stream, None).await;
});
// Set up the engine following the StaticFull pattern from http.rs
......
......@@ -75,7 +75,7 @@ pub use client::{Client, InstanceSource};
/// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "v1/instances";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum TransportType {
NatsTcp(String),
......@@ -278,21 +278,24 @@ impl Component {
}
pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let client = self.drt.store();
let Some(bucket) = client.get_bucket(&self.instance_root()).await? else {
return Ok(vec![]);
};
let entries = bucket.entries().await?;
let mut instances = Vec::with_capacity(entries.len());
for (name, bytes) in entries.into_iter() {
let val = match serde_json::from_slice::<Instance>(&bytes) {
Ok(val) => val,
Err(err) => {
anyhow::bail!("Error converting storage response to Instance: {err}. {name}",);
}
let discovery = self.drt.discovery();
let discovery_query = crate::discovery::DiscoveryQuery::ComponentEndpoints {
namespace: self.namespace.name(),
component: self.name.clone(),
};
instances.push(val);
}
let discovery_instances = discovery.list(discovery_query).await?;
// Extract Instance from DiscoveryInstance::Endpoint wrapper
let mut instances: Vec<Instance> = discovery_instances
.into_iter()
.filter_map(|di| match di {
crate::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
_ => None, // Ignore all other variants (ModelCard, etc.)
})
.collect();
instances.sort();
Ok(instances)
}
......
......@@ -9,6 +9,7 @@ use crate::{
storage::key_value_store::{KeyValueStoreManager, WatchEvent},
};
use arc_swap::ArcSwap;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::unix::pipe::Receiver;
......@@ -67,18 +68,33 @@ impl Client {
/// Creates a client that auto-discovers the instances of `endpoint`.
///
/// Obtains the dynamic instance source for the endpoint, builds the client with
/// empty available/free instance lists, and starts the instance source monitor
/// before returning.
///
/// # Errors
///
/// Returns an error if the dynamic instance source cannot be obtained.
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
    // Rendered once for all log lines below, so `endpoint` can be moved into the
    // client instead of being cloned.
    let endpoint_path = endpoint.path();
    tracing::debug!(
        "Client::new_dynamic: Creating dynamic client for endpoint: {}",
        endpoint_path
    );
    const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);

    // create live endpoint watcher
    let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
    tracing::debug!(
        "Client::new_dynamic: Got instance source for endpoint: {}",
        endpoint_path
    );

    let client = Client {
        endpoint,
        instance_source,
        instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
        instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
    };

    tracing::debug!(
        "Client::new_dynamic: Starting instance source monitor for endpoint: {}",
        endpoint_path
    );
    client.monitor_instance_source();
    tracing::debug!(
        "Client::new_dynamic: Successfully created dynamic client for endpoint: {}",
        endpoint_path
    );
    Ok(client)
}
......@@ -113,17 +129,47 @@ impl Client {
/// Blocks until this Endpoint has at least one Instance and returns the instances seen.
///
/// With a static instance source there is nothing to wait for, so an empty list is
/// returned immediately. With a dynamic source, the current instance list is re-read
/// after every change notification until it is non-empty.
///
/// # Errors
///
/// Propagates the error from waiting on the change notification.
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
    tracing::debug!(
        "wait_for_instances: Starting wait for endpoint: {}",
        self.endpoint.path()
    );

    // Guard clause: only a dynamic source can be waited on.
    let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() else {
        tracing::debug!(
            "wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}",
            self.endpoint.path()
        );
        return Ok(vec![]);
    };

    // wait for there to be 1 or more endpoints
    let mut iteration = 0;
    loop {
        let current: Vec<Instance> = rx.borrow_and_update().to_vec();
        tracing::debug!(
            "wait_for_instances: iteration={}, current_instance_count={}, endpoint={}",
            iteration,
            current.len(),
            self.endpoint.path()
        );

        if !current.is_empty() {
            tracing::info!(
                "wait_for_instances: Found {} instance(s) for endpoint: {}",
                current.len(),
                self.endpoint.path()
            );
            return Ok(current);
        }

        tracing::debug!(
            "wait_for_instances: No instances yet, waiting for change notification for endpoint: {}",
            self.endpoint.path()
        );
        rx.changed().await?;
        tracing::debug!(
            "wait_for_instances: Change notification received for endpoint: {}",
            self.endpoint.path()
        );
        iteration += 1;
    }
}
......@@ -159,14 +205,22 @@ impl Client {
/// Spawns a background task that copies instance ids from the dynamic instance
/// source into `instance_avail` and `instance_free`.
///
/// The task loops until the runtime's primary cancellation token is cancelled.
/// For a `Static` instance source it logs an error and returns immediately.
fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
let endpoint_path = self.endpoint.path();
tracing::debug!(
"monitor_instance_source: Starting monitor for endpoint: {}",
endpoint_path
);
tokio::task::spawn(async move {
// Only a dynamic source carries a watch receiver; the task works on its own clone.
let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => {
tracing::error!("Static instance source is not watchable");
tracing::error!(
"monitor_instance_source: Static instance source is not watchable"
);
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
let mut iteration = 0;
while !cancel_token.is_cancelled() {
// Read the current value from the receiver and collect the instance ids.
let instance_ids: Vec<u64> = rx
.borrow_and_update()
......@@ -174,17 +228,37 @@ impl Client {
.map(|instance| instance.id())
.collect();
tracing::debug!(
"monitor_instance_source: iteration={}, instance_count={}, instance_ids={:?}, endpoint={}",
iteration,
instance_ids.len(),
instance_ids,
endpoint_path
);
// TODO: this resets both tracked available and free instances
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids));
client.instance_free.store(Arc::new(instance_ids.clone()));
tracing::debug!("instance source updated");
tracing::debug!(
"monitor_instance_source: instance source updated, endpoint={}",
endpoint_path
);
// Wait for the next update. On error the primary token is cancelled, which
// also makes the `while` condition above false on the next check.
// NOTE(review): the token comes from `drt().primary_token()`, so cancelling it
// presumably affects more than this monitor — confirm that is intended.
if let Err(err) = rx.changed().await {
tracing::error!("The Sender is dropped: {}", err);
tracing::error!(
"monitor_instance_source: The Sender is dropped: {}, endpoint={}",
err,
endpoint_path
);
cancel_token.cancel();
}
iteration += 1;
}
tracing::debug!(
"monitor_instance_source: Monitor loop exiting for endpoint: {}",
endpoint_path
);
});
}
......@@ -195,100 +269,141 @@ impl Client {
let instance_sources = drt.instance_sources();
let mut instance_sources = instance_sources.lock().await;
tracing::debug!(
"get_or_create_dynamic_instance_source: Checking cache for endpoint: {}",
endpoint.path()
);
if let Some(instance_source) = instance_sources.get(endpoint) {
if let Some(instance_source) = instance_source.upgrade() {
tracing::debug!(
"get_or_create_dynamic_instance_source: Found cached instance source for endpoint: {}",
endpoint.path()
);
return Ok(instance_source);
} else {
tracing::debug!(
"get_or_create_dynamic_instance_source: Cached instance source was dropped, removing for endpoint: {}",
endpoint.path()
);
instance_sources.remove(endpoint);
}
}
let prefix = endpoint.etcd_root();
let store = Arc::new(drt.store().clone());
let (_, mut kv_event_rx) =
store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token());
tracing::debug!(
"get_or_create_dynamic_instance_source: Creating new instance source for endpoint: {}",
endpoint.path()
);
let discovery = drt.discovery();
let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
namespace: endpoint.component.namespace.name.clone(),
component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(),
};
tracing::debug!(
"get_or_create_dynamic_instance_source: Calling discovery.list_and_watch for query: {:?}",
discovery_query
);
let mut discovery_stream = discovery
.list_and_watch(discovery_query.clone(), None)
.await?;
tracing::debug!(
"get_or_create_dynamic_instance_source: Got discovery stream for query: {:?}",
discovery_query
);
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let secondary = endpoint.component.drt.runtime.secondary().clone();
// this task should be included in the registry
// currently this is created once per client, but this object/task should only be instantiated
// once per worker/instance
secondary.spawn(async move {
tracing::debug!("Starting endpoint watcher for prefix: {prefix}");
let mut map = HashMap::new();
tracing::debug!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
let mut map: HashMap<u64, Instance> = HashMap::new();
let mut event_count = 0;
loop {
let kv_event = tokio::select! {
let discovery_event = tokio::select! {
_ = watch_tx.closed() => {
tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}");
tracing::debug!("endpoint_watcher: all watchers have closed; shutting down for discovery query: {:?}", discovery_query);
break;
}
discovery_event = discovery_stream.next() => {
tracing::debug!("endpoint_watcher: Received stream event for discovery query: {:?}", discovery_query);
match discovery_event {
Some(Ok(event)) => {
tracing::debug!("endpoint_watcher: Got Ok event: {:?}", event);
event
},
Some(Err(e)) => {
tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
break;
}
kv_event = kv_event_rx.recv() => {
match kv_event {
Some(kv_event) => kv_event,
None => {
tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}");
tracing::debug!("endpoint_watcher: watch stream has closed; shutting down for discovery query: {:?}", discovery_query);
break;
}
}
}
};
match kv_event {
WatchEvent::Put(kv) => {
let key = kv.key_str();
if !key.starts_with(&prefix) {
continue;
}
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
}
event_count += 1;
tracing::debug!("endpoint_watcher: Processing event #{} for discovery query: {:?}", event_count, discovery_query);
match serde_json::from_slice::<Instance>(kv.value()) {
Ok(val) => map.insert(key.to_string(), val),
Err(err) => {
tracing::error!(error = %err, prefix,
"Unable to parse put endpoint event; shutting down endpoint watcher");
break;
match discovery_event {
crate::discovery::DiscoveryEvent::Added(discovery_instance) => {
match discovery_instance {
crate::discovery::DiscoveryInstance::Endpoint(instance) => {
tracing::debug!(
"endpoint_watcher: Added endpoint instance_id={}, namespace={}, component={}, endpoint={}",
instance.instance_id,
instance.namespace,
instance.component,
instance.endpoint
);
map.insert(instance.instance_id, instance);
}
};
_ => {
tracing::debug!("endpoint_watcher: Ignoring non-endpoint instance (Model, etc.) for discovery query: {:?}", discovery_query);
}
WatchEvent::Delete(key) => {
let key = key.as_ref();
if !key.starts_with(&prefix) {
continue;
}
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
}
map.remove(key);
crate::discovery::DiscoveryEvent::Removed(instance_id) => {
tracing::debug!(
"endpoint_watcher: Removed instance_id={} for discovery query: {:?}",
instance_id,
discovery_query
);
map.remove(&instance_id);
}
}
let instances: Vec<Instance> = map.values().cloned().collect();
tracing::debug!(
"endpoint_watcher: Current map size={}, sending update for discovery query: {:?}",
instances.len(),
discovery_query
);
if watch_tx.send(instances).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
tracing::debug!("endpoint_watcher: Unable to send watch updates; shutting down for discovery query: {:?}", discovery_query);
break;
}
}
tracing::debug!("Completed endpoint watcher for prefix: {prefix}");
tracing::debug!("endpoint_watcher: Completed for discovery query: {:?}, total events processed: {}", discovery_query, event_count);
let _ = watch_tx.send(vec![]);
});
let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
tracing::debug!(
"get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}",
endpoint.path()
);
Ok(instance_source)
}
}
......@@ -193,26 +193,23 @@ impl EndpointConfigBuilder {
result
});
let info = Instance {
// Register this endpoint instance in the discovery plane
// The discovery interface abstracts storage backend (etcd, k8s, etc) and provides
// consistent registration/discovery across the system.
let discovery = endpoint.drt().discovery();
let discovery_spec = crate::discovery::DiscoverySpec::Endpoint {
namespace: namespace_name.clone(),
component: component_name.clone(),
endpoint: endpoint_name.clone(),
namespace: namespace_name.clone(),
instance_id: connection_id,
transport: TransportType::NatsTcp(subject),
transport: TransportType::NatsTcp(subject.clone()),
};
let info = serde_json::to_vec_pretty(&info)?;
let store = endpoint.drt().store();
let instances_bucket = store
.get_or_create_bucket(super::INSTANCE_ROOT_PATH, None)
.await?;
let key = key_value_store::Key::from_raw(endpoint.unique_path(connection_id));
if let Err(err) = instances_bucket.insert(&key, info.into(), 0).await {
if let Err(e) = discovery.register(discovery_spec).await {
tracing::error!(
component_name,
endpoint_name,
error = %err,
error = %e,
"Unable to register service for discovery"
);
endpoint_shutdown_token.cancel();
......@@ -220,6 +217,7 @@ impl EndpointConfigBuilder {
"Unable to register service for discovery. Check discovery service status"
));
}
task.await??;
Ok(())
......
This diff is collapsed.
......@@ -2,10 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
use super::{
DiscoveryClient, DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoverySpec,
DiscoveryStream,
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
use crate::Result;
use crate::{CancellationToken, Result};
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
......@@ -21,14 +20,14 @@ impl SharedMockRegistry {
}
}
/// Mock implementation of DiscoveryClient for testing
/// We can potentially remove this once we have KeyValueDiscoveryClient implemented
pub struct MockDiscoveryClient {
/// Mock implementation of Discovery for testing
/// We can potentially remove this once we have KVStoreDiscovery fully tested
pub struct MockDiscovery {
instance_id: u64,
registry: SharedMockRegistry,
}
impl MockDiscoveryClient {
impl MockDiscovery {
pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
let instance_id = instance_id.unwrap_or_else(|| {
use std::sync::atomic::{AtomicU64, Ordering};
......@@ -43,24 +42,24 @@ impl MockDiscoveryClient {
}
}
/// Helper function to check if an instance matches a discovery key query
fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
match (instance, key) {
/// Helper function to check if an instance matches a discovery query
fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
match (instance, query) {
// Endpoint matching
(DiscoveryInstance::Endpoint(_), DiscoveryKey::AllEndpoints) => true,
(DiscoveryInstance::Endpoint(inst), DiscoveryKey::NamespacedEndpoints { namespace }) => {
(DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
(DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
&inst.namespace == namespace
}
(
DiscoveryInstance::Endpoint(inst),
DiscoveryKey::ComponentEndpoints {
DiscoveryQuery::ComponentEndpoints {
namespace,
component,
},
) => &inst.namespace == namespace && &inst.component == component,
(
DiscoveryInstance::Endpoint(inst),
DiscoveryKey::Endpoint {
DiscoveryQuery::Endpoint {
namespace,
component,
endpoint,
......@@ -71,33 +70,33 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
&& &inst.endpoint == endpoint
}
// ModelCard matching
(DiscoveryInstance::ModelCard { .. }, DiscoveryKey::AllModelCards) => true,
// Model matching
(DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
(
DiscoveryInstance::ModelCard {
DiscoveryInstance::Model {
namespace: inst_ns, ..
},
DiscoveryKey::NamespacedModelCards { namespace },
DiscoveryQuery::NamespacedModels { namespace },
) => inst_ns == namespace,
(
DiscoveryInstance::ModelCard {
DiscoveryInstance::Model {
namespace: inst_ns,
component: inst_comp,
..
},
DiscoveryKey::ComponentModelCards {
DiscoveryQuery::ComponentModels {
namespace,
component,
},
) => inst_ns == namespace && inst_comp == component,
(
DiscoveryInstance::ModelCard {
DiscoveryInstance::Model {
namespace: inst_ns,
component: inst_comp,
endpoint: inst_ep,
..
},
DiscoveryKey::EndpointModelCards {
DiscoveryQuery::EndpointModels {
namespace,
component,
endpoint,
......@@ -107,23 +106,23 @@ fn matches_key(instance: &DiscoveryInstance, key: &DiscoveryKey) -> bool {
// Cross-type matches return false
(
DiscoveryInstance::Endpoint(_),
DiscoveryKey::AllModelCards
| DiscoveryKey::NamespacedModelCards { .. }
| DiscoveryKey::ComponentModelCards { .. }
| DiscoveryKey::EndpointModelCards { .. },
DiscoveryQuery::AllModels
| DiscoveryQuery::NamespacedModels { .. }
| DiscoveryQuery::ComponentModels { .. }
| DiscoveryQuery::EndpointModels { .. },
) => false,
(
DiscoveryInstance::ModelCard { .. },
DiscoveryKey::AllEndpoints
| DiscoveryKey::NamespacedEndpoints { .. }
| DiscoveryKey::ComponentEndpoints { .. }
| DiscoveryKey::Endpoint { .. },
DiscoveryInstance::Model { .. },
DiscoveryQuery::AllEndpoints
| DiscoveryQuery::NamespacedEndpoints { .. }
| DiscoveryQuery::ComponentEndpoints { .. }
| DiscoveryQuery::Endpoint { .. },
) => false,
}
}
#[async_trait]
impl DiscoveryClient for MockDiscoveryClient {
impl Discovery for MockDiscovery {
fn instance_id(&self) -> u64 {
self.instance_id
}
......@@ -140,16 +139,20 @@ impl DiscoveryClient for MockDiscoveryClient {
Ok(instance)
}
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>> {
async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
let instances = self.registry.instances.lock().unwrap();
Ok(instances
.iter()
.filter(|instance| matches_key(instance, &key))
.filter(|instance| matches_query(instance, &query))
.cloned()
.collect())
}
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream> {
async fn list_and_watch(
&self,
query: DiscoveryQuery,
_cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream> {
use std::collections::HashSet;
let registry = self.registry.clone();
......@@ -162,7 +165,7 @@ impl DiscoveryClient for MockDiscoveryClient {
let instances = registry.instances.lock().unwrap();
instances
.iter()
.filter(|instance| matches_key(instance, &key))
.filter(|instance| matches_query(instance, &query))
.cloned()
.collect()
};
......@@ -170,7 +173,7 @@ impl DiscoveryClient for MockDiscoveryClient {
let current_ids: HashSet<_> = current.iter().map(|i| {
match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
}
}).collect();
......@@ -178,7 +181,7 @@ impl DiscoveryClient for MockDiscoveryClient {
for instance in current {
let id = match &instance {
DiscoveryInstance::Endpoint(inst) => inst.instance_id,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id,
DiscoveryInstance::Model { instance_id, .. } => *instance_id,
};
if known_instances.insert(id) {
yield Ok(DiscoveryEvent::Added(instance));
......@@ -207,8 +210,8 @@ mod tests {
#[tokio::test]
async fn test_mock_discovery_add_and_remove() {
let registry = SharedMockRegistry::new();
let client1 = MockDiscoveryClient::new(Some(1), registry.clone());
let client2 = MockDiscoveryClient::new(Some(2), registry.clone());
let client1 = MockDiscovery::new(Some(1), registry.clone());
let client2 = MockDiscovery::new(Some(2), registry.clone());
let spec = DiscoverySpec::Endpoint {
namespace: "test-ns".to_string(),
......@@ -217,14 +220,14 @@ mod tests {
transport: crate::component::TransportType::NatsTcp("test-subject".to_string()),
};
let key = DiscoveryKey::Endpoint {
let query = DiscoveryQuery::Endpoint {
namespace: "test-ns".to_string(),
component: "test-comp".to_string(),
endpoint: "test-ep".to_string(),
};
// Start watching
let mut stream = client1.list_and_watch(key.clone()).await.unwrap();
let mut stream = client1.list_and_watch(query.clone(), None).await.unwrap();
// Add first instance
client1.register(spec.clone()).await.unwrap();
......@@ -251,7 +254,7 @@ mod tests {
// Remove first instance
registry.instances.lock().unwrap().retain(|i| match i {
DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
DiscoveryInstance::ModelCard { instance_id, .. } => *instance_id != 1,
DiscoveryInstance::Model { instance_id, .. } => *instance_id != 1,
});
let event = stream.next().await.unwrap().unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::CancellationToken;
use crate::Result;
use crate::component::TransportType;
use async_trait::async_trait;
......@@ -9,7 +10,10 @@ use serde::{Deserialize, Serialize};
use std::pin::Pin;
mod mock;
pub use mock::{MockDiscoveryClient, SharedMockRegistry};
pub use mock::{MockDiscovery, SharedMockRegistry};
mod kv_store;
pub use kv_store::KVStoreDiscovery;
pub mod utils;
pub use utils::watch_and_extract_field;
......@@ -17,7 +21,7 @@ pub use utils::watch_and_extract_field;
/// Query key for prefix-based discovery queries
/// Supports hierarchical queries from all endpoints down to specific endpoints
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscoveryKey {
pub enum DiscoveryQuery {
/// Query all endpoints in the system
AllEndpoints,
/// Query all endpoints in a specific namespace
......@@ -35,15 +39,15 @@ pub enum DiscoveryKey {
component: String,
endpoint: String,
},
AllModelCards,
NamespacedModelCards {
AllModels,
NamespacedModels {
namespace: String,
},
ComponentModelCards {
ComponentModels {
namespace: String,
component: String,
},
EndpointModelCards {
EndpointModels {
namespace: String,
component: String,
endpoint: String,
......@@ -62,21 +66,21 @@ pub enum DiscoverySpec {
/// Transport type and routing information
transport: TransportType,
},
ModelCard {
Model {
namespace: String,
component: String,
endpoint: String,
/// ModelDeploymentCard serialized as JSON
/// This allows lib/runtime to remain independent of lib/llm types
/// DiscoverySpec.from_model_card() and DiscoveryInstance.deserialize_model_card() are ergonomic helpers to create and deserialize the model card.
/// DiscoverySpec.from_model() and DiscoveryInstance.deserialize_model() are ergonomic helpers to create and deserialize the model card.
card_json: serde_json::Value,
},
}
impl DiscoverySpec {
/// Creates a ModelCard discovery spec from a serializable type
/// Creates a Model discovery spec from a serializable type
/// The card will be serialized to JSON to avoid cross-crate dependencies
pub fn from_model_card<T>(
pub fn from_model<T>(
namespace: String,
component: String,
endpoint: String,
......@@ -86,7 +90,7 @@ impl DiscoverySpec {
T: Serialize,
{
let card_json = serde_json::to_value(card)?;
Ok(Self::ModelCard {
Ok(Self::Model {
namespace,
component,
endpoint,
......@@ -109,12 +113,12 @@ impl DiscoverySpec {
instance_id,
transport,
}),
Self::ModelCard {
Self::Model {
namespace,
component,
endpoint,
card_json,
} => DiscoveryInstance::ModelCard {
} => DiscoveryInstance::Model {
namespace,
component,
endpoint,
......@@ -132,7 +136,7 @@ impl DiscoverySpec {
pub enum DiscoveryInstance {
/// Registered endpoint instance - wraps the component::Instance directly
Endpoint(crate::component::Instance),
ModelCard {
Model {
namespace: String,
component: String,
endpoint: String,
......@@ -148,26 +152,26 @@ impl DiscoveryInstance {
pub fn instance_id(&self) -> u64 {
match self {
Self::Endpoint(inst) => inst.instance_id,
Self::ModelCard { instance_id, .. } => *instance_id,
Self::Model { instance_id, .. } => *instance_id,
}
}
/// Deserializes the model card JSON into the specified type T
/// Returns an error if this is not a ModelCard instance or if deserialization fails
pub fn deserialize_model_card<T>(&self) -> crate::Result<T>
/// Deserializes the model JSON into the specified type T
/// Returns an error if this is not a Model instance or if deserialization fails
pub fn deserialize_model<T>(&self) -> crate::Result<T>
where
T: for<'de> Deserialize<'de>,
{
match self {
Self::ModelCard { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
Self::Endpoint(_) => {
crate::raise!("Cannot deserialize model card from Endpoint instance")
crate::raise!("Cannot deserialize model from Endpoint instance")
}
}
}
}
/// Events emitted by the discovery client watch stream
/// Events emitted by the discovery watch stream
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiscoveryEvent {
/// A new instance was added
......@@ -179,9 +183,9 @@ pub enum DiscoveryEvent {
/// Stream type for discovery events
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
/// Discovery client trait for service discovery across different backends
/// Discovery trait for service discovery across different backends
#[async_trait]
pub trait DiscoveryClient: Send + Sync {
pub trait Discovery: Send + Sync {
/// Returns a unique identifier for this worker (e.g. lease id if using etcd or generated id for memory store)
/// Discovery objects created by this worker will be associated with this id.
fn instance_id(&self) -> u64;
......@@ -189,10 +193,15 @@ pub trait DiscoveryClient: Send + Sync {
/// Registers an object in the discovery plane with the instance id
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
/// Returns a list of currently registered instances for the given discovery key
/// Returns a list of currently registered instances for the given discovery query
/// This is a one-time snapshot without watching for changes
async fn list(&self, key: DiscoveryKey) -> Result<Vec<DiscoveryInstance>>;
/// Returns a stream of discovery events (Added/Removed) for the given discovery key
async fn list_and_watch(&self, key: DiscoveryKey) -> Result<DiscoveryStream>;
async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>>;
/// Returns a stream of discovery events (Added/Removed) for the given discovery query
/// The optional cancellation token can be used to stop the watch stream
async fn list_and_watch(
&self,
query: DiscoveryQuery,
cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream>;
}
......@@ -26,7 +26,7 @@ use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream};
///
/// # Example
/// ```ignore
/// let stream = discovery.list_and_watch(DiscoveryKey::ComponentModelCards { ... }).await?;
/// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
/// let runtime_configs_rx = watch_and_extract_field(
/// stream,
/// |card: ModelDeploymentCard| card.runtime_config,
......@@ -62,7 +62,7 @@ where
let instance_id = instance.instance_id();
// Deserialize the full instance into type T
let deserialized: T = match instance.deserialize_model_card() {
let deserialized: T = match instance.deserialize_model() {
Ok(d) => d,
Err(e) => {
tracing::warn!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment