Unverified Commit 5a0d710b authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(discovery): Use Store interface instead of etcd (#3887)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 6deeecb1
...@@ -14,7 +14,7 @@ use dynamo_runtime::{ ...@@ -14,7 +14,7 @@ use dynamo_runtime::{
network::egress::push_router::PushRouter, network::egress::push_router::PushRouter,
}, },
protocols::{EndpointId, annotated::Annotated}, protocols::{EndpointId, annotated::Annotated},
transports::etcd::WatchEvent, storage::key_value_store::WatchEvent,
}; };
use crate::{ use crate::{
...@@ -105,31 +105,11 @@ impl ModelWatcher { ...@@ -105,31 +105,11 @@ impl ModelWatcher {
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
match event { match event {
WatchEvent::Put(kv) => { WatchEvent::Put(kv) => {
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) { let key = kv.key_str();
Ok(card) => card, let endpoint_id = match key_extract(key) {
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
continue;
}
};
let endpoint_id = match etcd_key_extract(key) {
Ok((eid, _)) => eid, Ok((eid, _)) => eid,
Err(err) => { Err(err) => {
tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance."); tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance.");
continue; continue;
} }
}; };
...@@ -142,12 +122,26 @@ impl ModelWatcher { ...@@ -142,12 +122,26 @@ impl ModelWatcher {
tracing::debug!( tracing::debug!(
model_namespace = endpoint_id.namespace, model_namespace = endpoint_id.namespace,
target_namespace = target_ns, target_namespace = target_ns,
model_name = card.name(),
"Skipping model from different namespace" "Skipping model from different namespace"
); );
continue; continue;
} }
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
// If we already have a worker for this model, and the ModelDeploymentCard // If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance // cards don't match, alert, and don't add the new instance
let can_add = let can_add =
...@@ -190,10 +184,7 @@ impl ModelWatcher { ...@@ -190,10 +184,7 @@ impl ModelWatcher {
} }
} }
WatchEvent::Delete(kv) => { WatchEvent::Delete(kv) => {
let Ok(deleted_key) = kv.key_str() else { let deleted_key = kv.key_str();
tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
continue;
};
match self match self
.handle_delete(deleted_key, target_namespace, global_namespace) .handle_delete(deleted_key, target_namespace, global_namespace)
.await .await
...@@ -304,7 +295,7 @@ impl ModelWatcher { ...@@ -304,7 +295,7 @@ impl ModelWatcher {
Ok(Some(model_name)) Ok(Some(model_name))
} }
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from store, this usually means adding a new model to the list of served
// models. // models.
async fn handle_put( async fn handle_put(
&self, &self,
...@@ -569,8 +560,6 @@ impl ModelWatcher { ...@@ -569,8 +560,6 @@ impl ModelWatcher {
/// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> { async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let store = self.drt.store(); let store = self.drt.store();
//let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else { let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else {
// no cards // no cards
return Ok(vec![]); return Ok(vec![]);
...@@ -582,11 +571,11 @@ impl ModelWatcher { ...@@ -582,11 +571,11 @@ impl ModelWatcher {
let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) { let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) {
Ok(card) => { Ok(card) => {
let maybe_endpoint_id = let maybe_endpoint_id =
etcd_key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id); key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id);
let endpoint_id = match maybe_endpoint_id { let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid, Ok(eid) => eid,
Err(err) => { Err(err) => {
tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId"); tracing::error!(%err, "Skipping invalid key, not string or not EndpointId");
continue; continue;
} }
}; };
...@@ -623,9 +612,9 @@ impl ModelWatcher { ...@@ -623,9 +612,9 @@ impl ModelWatcher {
} }
} }
/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad". /// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that. /// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> { fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
if !s.starts_with(model_card::ROOT_PATH) { if !s.starts_with(model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}"); anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
} }
...@@ -649,12 +638,12 @@ mod tests { ...@@ -649,12 +638,12 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_etcd_key_extract() { fn test_key_extract() {
let input = format!( let input = format!(
"{}/dynamo/backend/generate/694d9981145a61ad", "{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH model_card::ROOT_PATH
); );
let (endpoint_id, _) = etcd_key_extract(&input).unwrap(); let (endpoint_id, _) = key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo"); assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend"); assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate"); assert_eq!(endpoint_id.name, "generate");
......
...@@ -62,9 +62,7 @@ pub async fn prepare_engine( ...@@ -62,9 +62,7 @@ pub async fn prepare_engine(
EngineConfig::Dynamic(local_model) => { EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let Some(etcd_client) = distributed_runtime.etcd_client() else { let store = Arc::new(distributed_runtime.store().clone());
anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
};
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new( let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime, distributed_runtime,
...@@ -73,11 +71,7 @@ pub async fn prepare_engine( ...@@ -73,11 +71,7 @@ pub async fn prepare_engine(
None, None,
None, None,
)); ));
let models_watcher = etcd_client let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token());
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await; inner_watch_obj.watch(receiver, None).await;
...@@ -100,9 +94,6 @@ pub async fn prepare_engine( ...@@ -100,9 +94,6 @@ pub async fn prepare_engine(
}) })
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
// For now we only do ModelType.Backend
// For batch/text we only do Chat Completions
// The card should have been loaded at 'build' phase earlier // The card should have been loaded at 'build' phase earlier
let card = local_model.card(); let card = local_model.card();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
......
...@@ -16,27 +16,22 @@ use crate::{ ...@@ -16,27 +16,22 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::transports::etcd; use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager};
use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
/// Build and run a KServe gRPC service /// Build and run a KServe gRPC service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut grpc_service_builder = kserve::KserveService::builder() let grpc_service_builder = kserve::KserveService::builder()
.port(engine_config.local_model().http_port()) // [WIP] generalize port.. .port(engine_config.local_model().http_port()) // [WIP] generalize port..
.with_request_template(engine_config.local_model().request_template()); .with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let etcd_client = distributed_runtime.etcd_client(); let store = Arc::new(distributed_runtime.store().clone());
// This allows the /health endpoint to query etcd for active instances
grpc_service_builder = grpc_service_builder.with_etcd_client(etcd_client.clone());
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
match etcd_client {
Some(ref etcd_client) => {
let router_config = engine_config.local_model().router_config(); let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace = engine_config.local_model().namespace().unwrap_or(""); let namespace = engine_config.local_model().namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) { let target_namespace = if is_global_namespace(namespace) {
None None
...@@ -46,19 +41,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -46,19 +41,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
run_watcher( run_watcher(
distributed_runtime, distributed_runtime,
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
etcd_client.clone(), store,
model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
target_namespace, target_namespace,
) )
.await?; .await?;
}
None => {
// Static endpoints don't need discovery
}
}
grpc_service grpc_service
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
...@@ -173,19 +162,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -173,19 +162,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Ok(()) Ok(())
} }
/// Spawns a task that watches for new models in etcd at network_prefix, /// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
etcd_client: etcd::Client, store: Arc<KeyValueStoreManager>,
network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
target_namespace: Option<String>, target_namespace: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let watch_obj = ModelWatcher::new( let watch_obj = ModelWatcher::new(
runtime, runtime,
model_manager, model_manager,
...@@ -193,9 +182,8 @@ async fn run_watcher( ...@@ -193,9 +182,8 @@ async fn run_watcher(
kv_router_config, kv_router_config,
busy_threshold, busy_threshold,
); );
tracing::info!("Watching for remote model at {network_prefix}"); tracing::debug!("Waiting for remote model");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
// [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service // [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service
// endpoint being exposed, gRPC doesn't have the same concept as the KServe service // endpoint being exposed, gRPC doesn't have the same concept as the KServe service
......
...@@ -17,7 +17,7 @@ use crate::{ ...@@ -17,7 +17,7 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::transports::etcd; use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
...@@ -64,14 +64,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -64,14 +64,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = match engine_config { let http_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
// This allows the /health endpoint to query etcd for active instances // This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let etcd_client = distributed_runtime.etcd_client(); let store = Arc::new(distributed_runtime.store().clone());
match etcd_client {
Some(ref etcd_client) => {
let router_config = engine_config.local_model().router_config(); let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace) // Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set // Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = engine_config.local_model().namespace().unwrap_or(""); let namespace = engine_config.local_model().namespace().unwrap_or("");
...@@ -83,8 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -83,8 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
run_watcher( run_watcher(
distributed_runtime, distributed_runtime,
http_service.state().manager_clone(), http_service.state().manager_clone(),
etcd_client.clone(), store,
model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -93,11 +91,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -93,11 +91,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
) )
.await?; .await?;
}
None => {
// Static endpoints don't need discovery
}
}
http_service http_service
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
...@@ -274,14 +267,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -274,14 +267,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Ok(()) Ok(())
} }
/// Spawns a task that watches for new models in etcd at network_prefix, /// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
etcd_client: etcd::Client, store: Arc<KeyValueStoreManager>,
network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
...@@ -289,6 +281,7 @@ async fn run_watcher( ...@@ -289,6 +281,7 @@ async fn run_watcher(
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime, runtime,
model_manager, model_manager,
...@@ -296,13 +289,11 @@ async fn run_watcher( ...@@ -296,13 +289,11 @@ async fn run_watcher(
kv_router_config, kv_router_config,
busy_threshold, busy_threshold,
); );
tracing::info!("Watching for remote model at {network_prefix}"); tracing::debug!("Waiting for remote model");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
// Create a channel to receive model type updates // Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32); let (tx, mut rx) = tokio::sync::mpsc::channel(32);
watch_obj.set_notify_on_model_update(tx); watch_obj.set_notify_on_model_update(tx);
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
......
...@@ -15,7 +15,6 @@ use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse}; ...@@ -15,7 +15,6 @@ use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::transports::etcd;
use futures::pin_mut; use futures::pin_mut;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::{Stream, StreamExt};
...@@ -45,7 +44,6 @@ use inference::{ ...@@ -45,7 +44,6 @@ use inference::{
pub struct State { pub struct State {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
etcd_client: Option<etcd::Client>,
} }
impl State { impl State {
...@@ -53,15 +51,6 @@ impl State { ...@@ -53,15 +51,6 @@ impl State {
Self { Self {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client: None,
}
}
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client: Some(etcd_client),
} }
} }
...@@ -78,10 +67,6 @@ impl State { ...@@ -78,10 +67,6 @@ impl State {
self.manager.clone() self.manager.clone()
} }
pub fn etcd_client(&self) -> Option<&etcd::Client> {
self.etcd_client.as_ref()
}
fn is_tensor_model(&self, model: &String) -> bool { fn is_tensor_model(&self, model: &String) -> bool {
self.manager.list_tensor_models().contains(model) self.manager.list_tensor_models().contains(model)
} }
...@@ -108,9 +93,6 @@ pub struct KserveServiceConfig { ...@@ -108,9 +93,6 @@ pub struct KserveServiceConfig {
#[builder(default = "None")] #[builder(default = "None")]
request_template: Option<RequestTemplate>, request_template: Option<RequestTemplate>,
#[builder(default = "None")]
etcd_client: Option<etcd::Client>,
} }
impl KserveService { impl KserveService {
...@@ -155,10 +137,7 @@ impl KserveServiceConfigBuilder { ...@@ -155,10 +137,7 @@ impl KserveServiceConfigBuilder {
let config: KserveServiceConfig = self.build_internal()?; let config: KserveServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let state = match config.etcd_client { let state = Arc::new(State::new(model_manager));
Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
// enable prometheus metrics // enable prometheus metrics
let registry = metrics::Registry::new(); let registry = metrics::Registry::new();
...@@ -176,11 +155,6 @@ impl KserveServiceConfigBuilder { ...@@ -176,11 +155,6 @@ impl KserveServiceConfigBuilder {
self.request_template = Some(request_template); self.request_template = Some(request_template);
self self
} }
pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
self.etcd_client = Some(etcd_client);
self
}
} }
#[tonic::async_trait] #[tonic::async_trait]
......
...@@ -343,20 +343,18 @@ mod integration_tests { ...@@ -343,20 +343,18 @@ mod integration_tests {
None, None,
None, None,
); );
// Start watching etcd for model registrations // Start watching the store for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() { let store = Arc::new(distributed_runtime.store().clone());
let models_watcher = etcd_client let (_, receiver) = store.watch(
.kv_get_and_watch_prefix(model_card::ROOT_PATH) model_card::ROOT_PATH,
.await None,
.unwrap(); distributed_runtime.primary_token(),
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); );
// Spawn watcher task to discover models from etcd // Spawn watcher task to discover models from the store
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
model_watcher.watch(receiver, None).await; model_watcher.watch(receiver, None).await;
}); });
}
// Set up the engine following the StaticFull pattern from http.rs // Set up the engine following the StaticFull pattern from http.rs
let EngineConfig::StaticFull { engine, model, .. } = engine_config else { let EngineConfig::StaticFull { engine, model, .. } = engine_config else {
......
...@@ -23,6 +23,8 @@ pub use nats::NATSStore; ...@@ -23,6 +23,8 @@ pub use nats::NATSStore;
mod etcd; mod etcd;
pub use etcd::EtcdStore; pub use etcd::EtcdStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// A key that is safe to use directly in the KV store. /// A key that is safe to use directly in the KV store.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Key(String); pub struct Key(String);
...@@ -72,6 +74,22 @@ impl KeyValue { ...@@ -72,6 +74,22 @@ impl KeyValue {
pub fn new(key: String, value: bytes::Bytes) -> Self { pub fn new(key: String, value: bytes::Bytes) -> Self {
KeyValue { key, value } KeyValue { key, value }
} }
/// Returns an owned copy of this entry's key.
///
/// Allocates a new `String`; use `key_str` when a borrow is enough.
pub fn key(&self) -> String {
    self.key.to_owned()
}
/// Borrows this entry's key as a string slice, without allocating.
pub fn key_str(&self) -> &str {
    self.key.as_str()
}
/// Borrows the raw bytes stored under this entry's key.
pub fn value(&self) -> &[u8] {
    &self.value[..]
}
/// Interprets the stored value as UTF-8 text.
///
/// # Errors
/// Returns the UTF-8 decoding error, converted to `anyhow::Error`, when the
/// value bytes are not valid UTF-8.
pub fn value_str(&self) -> anyhow::Result<&str> {
    let text = std::str::from_utf8(self.value())?;
    Ok(text)
}
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
...@@ -221,10 +239,10 @@ impl KeyValueStoreManager { ...@@ -221,10 +239,10 @@ impl KeyValueStoreManager {
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> ( ) -> (
tokio::task::JoinHandle<Result<(), StoreError>>, tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<WatchEvent>, tokio::sync::mpsc::Receiver<WatchEvent>,
) { ) {
let bucket_name = bucket_name.to_string(); let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let (tx, rx) = tokio::sync::mpsc::channel(128);
let watch_task = tokio::spawn(async move { let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet // Start listening for changes but don't poll this yet
let bucket = self let bucket = self
...@@ -235,7 +253,15 @@ impl KeyValueStoreManager { ...@@ -235,7 +253,15 @@ impl KeyValueStoreManager {
// Send all the existing keys // Send all the existing keys
for (key, bytes) in bucket.entries().await? { for (key, bytes) in bucket.entries().await? {
let _ = tx.send(WatchEvent::Put(KeyValue::new(key, bytes))); if let Err(err) = tx
.send_timeout(
WatchEvent::Put(KeyValue::new(key, bytes)),
WATCH_SEND_TIMEOUT,
)
.await
{
tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
}
} }
// Now block waiting for new entries // Now block waiting for new entries
...@@ -247,7 +273,9 @@ impl KeyValueStoreManager { ...@@ -247,7 +273,9 @@ impl KeyValueStoreManager {
None => break, None => break,
} }
}; };
let _ = tx.send(event); if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
}
} }
Ok::<(), StoreError>(()) Ok::<(), StoreError>(())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment