Commit 5a0d710b (unverified), authored by Graham King, committed by GitHub
Browse files

chore(discovery): Use Store interface instead of etcd (#3887)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 6deeecb1
......@@ -14,7 +14,7 @@ use dynamo_runtime::{
network::egress::push_router::PushRouter,
},
protocols::{EndpointId, annotated::Annotated},
transports::etcd::WatchEvent,
storage::key_value_store::WatchEvent,
};
use crate::{
......@@ -105,31 +105,11 @@ impl ModelWatcher {
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => {
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
continue;
}
};
let endpoint_id = match etcd_key_extract(key) {
let key = kv.key_str();
let endpoint_id = match key_extract(key) {
Ok((eid, _)) => eid,
Err(err) => {
tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance.");
tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance.");
continue;
}
};
......@@ -142,12 +122,26 @@ impl ModelWatcher {
tracing::debug!(
model_namespace = endpoint_id.namespace,
target_namespace = target_ns,
model_name = card.name(),
"Skipping model from different namespace"
);
continue;
}
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
// If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance
let can_add =
......@@ -190,10 +184,7 @@ impl ModelWatcher {
}
}
WatchEvent::Delete(kv) => {
let Ok(deleted_key) = kv.key_str() else {
tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
continue;
};
let deleted_key = kv.key_str();
match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.await
......@@ -304,7 +295,7 @@ impl ModelWatcher {
Ok(Some(model_name))
}
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// Handles a PUT event from store, this usually means adding a new model to the list of served
// models.
async fn handle_put(
&self,
......@@ -569,8 +560,6 @@ impl ModelWatcher {
/// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let store = self.drt.store();
//let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else {
// no cards
return Ok(vec![]);
......@@ -582,11 +571,11 @@ impl ModelWatcher {
let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) {
Ok(card) => {
let maybe_endpoint_id =
etcd_key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id);
key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id);
let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid,
Err(err) => {
tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId");
tracing::error!(%err, "Skipping invalid key, not string or not EndpointId");
continue;
}
};
......@@ -623,9 +612,9 @@ impl ModelWatcher {
}
}
/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
if !s.starts_with(model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
......@@ -649,12 +638,12 @@ mod tests {
use super::*;
#[test]
fn test_etcd_key_extract() {
fn test_key_extract() {
let input = format!(
"{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH
);
let (endpoint_id, _) = etcd_key_extract(&input).unwrap();
let (endpoint_id, _) = key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate");
......
......@@ -62,9 +62,7 @@ pub async fn prepare_engine(
EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
};
let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime,
......@@ -73,11 +71,7 @@ pub async fn prepare_engine(
None,
None,
));
let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token());
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await;
......@@ -100,9 +94,6 @@ pub async fn prepare_engine(
})
}
EngineConfig::StaticRemote(local_model) => {
// For now we only do ModelType.Backend
// For batch/text we only do Chat Completions
// The card should have been loaded at 'build' phase earlier
let card = local_model.card();
let router_mode = local_model.router_config().router_mode;
......
......@@ -16,27 +16,22 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
/// Build and run an KServe gRPC service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut grpc_service_builder = kserve::KserveService::builder()
let grpc_service_builder = kserve::KserveService::builder()
.port(engine_config.local_model().http_port()) // [WIP] generalize port..
.with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let etcd_client = distributed_runtime.etcd_client();
// This allows the /health endpoint to query etcd for active instances
grpc_service_builder = grpc_service_builder.with_etcd_client(etcd_client.clone());
let store = Arc::new(distributed_runtime.store().clone());
let grpc_service = grpc_service_builder.build()?;
match etcd_client {
Some(ref etcd_client) => {
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to gRPC service
// Listen for models registering themselves, add them to gRPC service
let namespace = engine_config.local_model().namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
......@@ -46,19 +41,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
run_watcher(
distributed_runtime,
grpc_service.state().manager_clone(),
etcd_client.clone(),
model_card::ROOT_PATH,
store,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
target_namespace,
)
.await?;
}
None => {
// Static endpoints don't need discovery
}
}
grpc_service
}
EngineConfig::StaticRemote(local_model) => {
......@@ -173,19 +162,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Ok(())
}
/// Spawns a task that watches for new models in etcd at network_prefix,
/// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
etcd_client: etcd::Client,
network_prefix: &str,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let watch_obj = ModelWatcher::new(
runtime,
model_manager,
......@@ -193,9 +182,8 @@ async fn run_watcher(
kv_router_config,
busy_threshold,
);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
// [gluo NOTE] This is different from http::run_watcher where it alters the HTTP service
// endpoint being exposed, gRPC doesn't have the same concept as the KServe service
......
......@@ -17,7 +17,7 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::transports::etcd;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
......@@ -64,14 +64,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = match engine_config {
EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
// This allows the /health endpoint to query etcd for active instances
// This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?;
let etcd_client = distributed_runtime.etcd_client();
match etcd_client {
Some(ref etcd_client) => {
let store = Arc::new(distributed_runtime.store().clone());
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to HTTP service
// Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = engine_config.local_model().namespace().unwrap_or("");
......@@ -83,8 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
run_watcher(
distributed_runtime,
http_service.state().manager_clone(),
etcd_client.clone(),
model_card::ROOT_PATH,
store,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -93,11 +91,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
http_service.state().metrics_clone(),
)
.await?;
}
None => {
// Static endpoints don't need discovery
}
}
http_service
}
EngineConfig::StaticRemote(local_model) => {
......@@ -274,14 +267,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Ok(())
}
/// Spawns a task that watches for new models in etcd at network_prefix,
/// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
etcd_client: etcd::Client,
network_prefix: &str,
store: Arc<KeyValueStoreManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
......@@ -289,6 +281,7 @@ async fn run_watcher(
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> {
let cancellation_token = runtime.primary_token();
let mut watch_obj = ModelWatcher::new(
runtime,
model_manager,
......@@ -296,13 +289,11 @@ async fn run_watcher(
kv_router_config,
busy_threshold,
);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
tracing::debug!("Waiting for remote model");
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
// Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
watch_obj.set_notify_on_model_update(tx);
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
......
......@@ -15,7 +15,6 @@ use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::transports::etcd;
use futures::pin_mut;
use tokio::task::JoinHandle;
use tokio_stream::{Stream, StreamExt};
......@@ -45,7 +44,6 @@ use inference::{
pub struct State {
metrics: Arc<Metrics>,
manager: Arc<ModelManager>,
etcd_client: Option<etcd::Client>,
}
impl State {
......@@ -53,15 +51,6 @@ impl State {
Self {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client: None,
}
}
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client: Some(etcd_client),
}
}
......@@ -78,10 +67,6 @@ impl State {
self.manager.clone()
}
pub fn etcd_client(&self) -> Option<&etcd::Client> {
self.etcd_client.as_ref()
}
/// Returns `true` when `model` is present in the manager's list of tensor models.
///
/// `model` is forwarded unchanged to `contains` on that list.
fn is_tensor_model(&self, model: &String) -> bool {
    let tensor_models = self.manager.list_tensor_models();
    tensor_models.contains(model)
}
......@@ -108,9 +93,6 @@ pub struct KserveServiceConfig {
#[builder(default = "None")]
request_template: Option<RequestTemplate>,
#[builder(default = "None")]
etcd_client: Option<etcd::Client>,
}
impl KserveService {
......@@ -155,10 +137,7 @@ impl KserveServiceConfigBuilder {
let config: KserveServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new());
let state = match config.etcd_client {
Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
let state = Arc::new(State::new(model_manager));
// enable prometheus metrics
let registry = metrics::Registry::new();
......@@ -176,11 +155,6 @@ impl KserveServiceConfigBuilder {
self.request_template = Some(request_template);
self
}
pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
self.etcd_client = Some(etcd_client);
self
}
}
#[tonic::async_trait]
......
......@@ -343,20 +343,18 @@ mod integration_tests {
None,
None,
);
// Start watching the store for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() {
let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await
.unwrap();
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let store = Arc::new(distributed_runtime.store().clone());
let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
// Spawn watcher task to discover models from the store
let _watcher_task = tokio::spawn(async move {
model_watcher.watch(receiver, None).await;
});
}
// Set up the engine following the StaticFull pattern from http.rs
let EngineConfig::StaticFull { engine, model, .. } = engine_config else {
......
......@@ -23,6 +23,8 @@ pub use nats::NATSStore;
mod etcd;
pub use etcd::EtcdStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// A key that is safe to use directly in the KV store.
#[derive(Debug, Clone, PartialEq)]
pub struct Key(String);
......@@ -72,6 +74,22 @@ impl KeyValue {
pub fn new(key: String, value: bytes::Bytes) -> Self {
KeyValue { key, value }
}
/// Returns an owned copy of the key.
///
/// Allocates a new `String` on every call; use `key_str` to borrow instead.
pub fn key(&self) -> String {
    self.key.to_owned()
}
/// Borrows the key as a string slice without allocating.
pub fn key_str(&self) -> &str {
    self.key.as_str()
}
/// Borrows the stored value as a raw byte slice.
pub fn value(&self) -> &[u8] {
    &self.value[..]
}
/// Interprets the stored value as UTF-8 text.
///
/// # Errors
///
/// Returns the `std::str::Utf8Error` (converted into `anyhow::Error`) when the
/// value bytes are not valid UTF-8.
pub fn value_str(&self) -> anyhow::Result<&str> {
    let text = std::str::from_utf8(&self.value)?;
    Ok(text)
}
}
#[derive(Debug, Clone, PartialEq)]
......@@ -221,10 +239,10 @@ impl KeyValueStoreManager {
cancel_token: CancellationToken,
) -> (
tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<WatchEvent>,
tokio::sync::mpsc::Receiver<WatchEvent>,
) {
let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let (tx, rx) = tokio::sync::mpsc::channel(128);
let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet
let bucket = self
......@@ -235,7 +253,15 @@ impl KeyValueStoreManager {
// Send all the existing keys
for (key, bytes) in bucket.entries().await? {
let _ = tx.send(WatchEvent::Put(KeyValue::new(key, bytes)));
if let Err(err) = tx
.send_timeout(
WatchEvent::Put(KeyValue::new(key, bytes)),
WATCH_SEND_TIMEOUT,
)
.await
{
tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
}
}
// Now block waiting for new entries
......@@ -247,7 +273,9 @@ impl KeyValueStoreManager {
None => break,
}
};
let _ = tx.send(event);
if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
}
}
Ok::<(), StoreError>(())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment