Unverified commit 50cdae5f, authored by Graham King, committed by GitHub
Browse files

chore: Add Key abstraction in our KeyValueStore (#3322)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent cf3ac5b6
......@@ -14,6 +14,7 @@ use dynamo_runtime::{
network::egress::push_router::PushRouter,
},
protocols::annotated::Annotated,
storage::key_value_store::Key,
transports::etcd::{KeyValue, WatchEvent},
};
......@@ -258,7 +259,12 @@ impl ModelWatcher {
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug();
let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
let card = match ModelDeploymentCard::load_from_store(
&Key::from_raw(model_slug.to_string()),
&self.drt,
)
.await
{
Ok(Some(mut card)) => {
tracing::debug!(card.display_name, "adding model");
// Ensure runtime_config is populated
......
......@@ -239,10 +239,7 @@ async fn run_watcher(
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_update) = rx.recv().await {
// Update HTTP endpoints (existing functionality)
update_http_endpoints(http_service.clone(), model_update.clone());
// Update metrics (only for added models)
update_model_metrics(model_update, metrics.clone());
}
});
......
......@@ -8,6 +8,7 @@ use std::sync::Arc;
use anyhow::Context as _;
use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{
component::Endpoint,
......@@ -417,13 +418,21 @@ impl LocalModel {
let nats_client = endpoint.drt().nats_client();
self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd
// Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();
// TODO: Next PR will use this
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
......
......@@ -23,7 +23,9 @@ use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
use dynamo_runtime::storage::key_value_store::{
EtcdStorage, Key, KeyValueStore, KeyValueStoreManager,
};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
......@@ -394,7 +396,7 @@ impl ModelDeploymentCard {
/// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use.
/// Card should be fully local and ready to use when the call returns.
pub async fn load_from_store(
model_slug: &Slug,
mdc_key: &Key,
drt: &DistributedRuntime,
) -> anyhow::Result<Option<Self>> {
let Some(etcd_client) = drt.etcd_client() else {
......@@ -404,7 +406,7 @@ impl ModelDeploymentCard {
let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
let card_store = Arc::new(KeyValueStoreManager::new(store));
let Some(mut card) = card_store
.load::<ModelDeploymentCard>(ROOT_PATH, model_slug)
.load::<ModelDeploymentCard>(ROOT_PATH, mdc_key)
.await?
else {
return Ok(None);
......
......@@ -225,8 +225,8 @@ impl Component {
&self.namespace
}
pub fn name(&self) -> String {
self.name.clone()
pub fn name(&self) -> &str {
&self.name
}
pub fn labels(&self) -> &[(String, String)] {
......@@ -457,7 +457,7 @@ impl Endpoint {
pub fn etcd_path(&self) -> EtcdPath {
EtcdPath::new_endpoint(
&self.component.namespace().name(),
&self.component.name(),
self.component.name(),
&self.name,
)
.expect("Endpoint name and component name should be valid")
......@@ -465,12 +465,15 @@ impl Endpoint {
/// The full path of an instance in etcd
pub fn etcd_path_with_lease_id(&self, lease_id: i64) -> String {
let endpoint_root = self.etcd_root();
if self.is_static {
endpoint_root
} else {
format!("{endpoint_root}:{lease_id:x}")
}
format!("{INSTANCE_ROOT_PATH}/{}", self.unique_path(lease_id))
}
/// Builds the slash-separated identifier of one instance of this endpoint.
///
/// The result has the shape `namespace/component/endpoint/lease`, with the
/// lease id rendered in lowercase hexadecimal (no `0x` prefix).
pub fn unique_path(&self, lease_id: i64) -> String {
    let component = &self.component;
    format!(
        "{}/{}/{}/{lease_id:x}",
        component.namespace().name(),
        component.name(),
        self.name(),
    )
}
/// The endpoint as an EtcdPath object with lease ID
......@@ -480,7 +483,7 @@ impl Endpoint {
} else {
EtcdPath::new_endpoint_with_lease(
&self.component.namespace().name(),
&self.component.name(),
self.component.name(),
&self.name,
lease_id,
)
......
......@@ -425,8 +425,8 @@ mod integration_tests {
let request_timeout = Duration::from_secs(3);
let config = HealthCheckConfig {
canary_wait_time: canary_wait_time,
request_timeout: request_timeout,
canary_wait_time,
request_timeout,
};
let manager = HealthCheckManager::new(drt.clone(), config);
......
......@@ -23,6 +23,45 @@ pub use nats::NATSStorage;
mod etcd;
pub use etcd::EtcdStorage;
/// A key that is safe to use directly in the KV store.
///
/// Newtype over `String`. `Eq` and `Hash` are derived alongside `PartialEq`
/// so that keys can be stored in `HashMap`/`HashSet` and compared with full
/// equality semantics; both follow the wrapped string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
impl Key {
pub fn new(s: &str) -> Key {
Key(Slug::slugify(s).to_string())
}
/// Create a Key without changing the string, it is assumed already KV store safe.
pub fn from_raw(s: String) -> Key {
Key(s)
}
}
impl From<&str> for Key {
fn from(s: &str) -> Key {
Key::new(s)
}
}
/// Renders the key as its underlying string, byte for byte.
impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit the text directly; no padding or other formatter flags apply.
        f.write_str(&self.0)
    }
}
/// Borrows the key as a plain string slice without copying.
impl AsRef<str> for Key {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Produces an owned copy of the key's text; the borrowed key is left intact.
impl From<&Key> for String {
    fn from(k: &Key) -> Self {
        k.0.to_owned()
    }
}
#[async_trait]
pub trait KeyValueStore: Send + Sync {
async fn get_or_create_bucket(
......@@ -48,13 +87,13 @@ impl KeyValueStoreManager {
pub async fn load<T: for<'a> Deserialize<'a>>(
&self,
bucket: &str,
key: &Slug,
key: &Key,
) -> Result<Option<T>, StorageError> {
let Some(bucket) = self.0.get_bucket(bucket).await? else {
// No bucket means no cards
return Ok(None);
};
match bucket.get(key.as_ref()).await {
match bucket.get(key).await {
Ok(Some(card_bytes)) => {
let card: T = serde_json::from_slice(card_bytes.as_ref())?;
Ok(Some(card))
......@@ -109,15 +148,13 @@ impl KeyValueStoreManager {
&self,
bucket_name: &str,
bucket_ttl: Option<Duration>,
key: &str,
key: &Key,
obj: &mut T,
) -> anyhow::Result<StorageOutcome> {
let obj_json = serde_json::to_string(obj)?;
let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
let outcome = bucket
.insert(key.to_string(), obj_json, obj.revision())
.await?;
let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
match outcome {
StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
......@@ -126,59 +163,6 @@ impl KeyValueStoreManager {
}
Ok(outcome)
}
/// Re-publish the model card to the store regularly. Spawns a task and returns.
/// Takes most arguments by value because it will hold on to them in the publish task.
/// Deletes the card on cancellation.
///
/// The spawned task publishes `obj` under `key` in `bucket_name` immediately, then
/// again every `publish_interval`, until `cancel_token` is cancelled. On
/// cancellation it looks the bucket up, attempts to delete `key` from it, and stops.
/// A failed publish is logged and does not stop the loop.
pub fn publish_until_cancelled<T: Serialize + Versioned + Send + Sync + 'static>(
    self: Arc<Self>,
    cancel_token: CancellationToken,
    bucket_name: String,
    bucket_ttl: Option<Duration>,
    publish_interval: Duration,
    key: String,
    mut obj: T,
) {
    tokio::spawn(async move {
        loop {
            // `publish` only needs `&self`, so the Arc does not have to be cloned.
            let publish_result = self
                .publish(&bucket_name, bucket_ttl, &key, &mut obj)
                .await;
            if let Err(err) = publish_result {
                // NOTE(review): the message says the task ends, but the loop keeps
                // running and retries after `publish_interval`. Confirm which of the
                // two is intended before changing either.
                tracing::error!(
                    model = key,
                    error = %err,
                    "Failed publishing to KV storage. Ending publish task.",
                );
            }

            tokio::select! {
                _ = tokio::time::sleep(publish_interval) => {},
                _ = cancel_token.cancelled() => {
                    tracing::trace!(model_service_name = key, "Publish loop cancelled");
                    match self.0.get_bucket(&bucket_name).await {
                        Ok(Some(bucket)) => match bucket.delete(&key).await {
                            // Only report the deletion when it actually happened.
                            Ok(()) => {
                                tracing::trace!(bucket_name, key, "Deleted Model Deployment Card from NATS");
                            }
                            Err(err) => {
                                // This is usually expected, our NATS connection is closed
                                tracing::trace!(bucket_name, key, %err, "Error delete published card from NATS on publish stop");
                            }
                        },
                        Ok(None) => {
                            tracing::trace!(bucket_name, key, "Bucket does not exist");
                        }
                        Err(err) => {
                            tracing::trace!(bucket_name, %err, "publish_until_cancelled shutdown error");
                        }
                    }
                    // Stop publishing
                    break;
                }
            }
        }
    });
}
}
/// An online storage for key-value config values.
......@@ -189,16 +173,16 @@ pub trait KeyValueBucket: Send {
/// Insert a value into a bucket, if it doesn't exist already
async fn insert(
&self,
key: String,
value: String,
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError>;
/// Fetch an item from the key-value storage
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError>;
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError>;
/// Delete an item from the bucket
async fn delete(&self, key: &str) -> Result<(), StorageError>;
async fn delete(&self, key: &Key) -> Result<(), StorageError>;
/// A stream of items inserted into the bucket.
/// Every time the stream is polled it will either return a newly created entry, or block until
......@@ -311,9 +295,7 @@ mod tests {
let s2 = Arc::clone(&s);
let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
let res = bucket
.insert("test1".to_string(), "value1".to_string(), 0)
.await?;
let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
......@@ -341,26 +323,18 @@ mod tests {
// wouldn't be testing the watch behavior.
got_first_rx.await?;
let res = bucket
.insert("test2".to_string(), "value2".to_string(), 0)
.await?;
let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
// Repeat a key and revision. Ignored.
let res = bucket
.insert("test2".to_string(), "value2".to_string(), 0)
.await?;
let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Exists(0));
// Increment revision
let res = bucket
.insert("test2".to_string(), "value2".to_string(), 1)
.await?;
let res = bucket.insert(&"test2".into(), "value2", 1).await?;
assert_eq!(res, StorageOutcome::Created(1));
let res = bucket
.insert("test3".to_string(), "value3".to_string(), 0)
.await?;
let res = bucket.insert(&"test3".into(), "value3", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
// ingress exits once it has received all values
......@@ -377,9 +351,7 @@ mod tests {
let bucket: &'static _ =
Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));
let res = bucket
.insert("test1".to_string(), "value1".to_string(), 0)
.await?;
let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
let stream = bucket.watch().await?;
......@@ -397,9 +369,7 @@ mod tests {
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
});
bucket
.insert("test1".to_string(), "GK".to_string(), 1)
.await?;
bucket.insert(&"test1".into(), "GK", 1).await?;
let _ = futures::join!(handle1, handle2);
Ok(())
......
......@@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use crate::{slug::Slug, transports::etcd::Client};
use crate::{slug::Slug, storage::key_value_store::Key, transports::etcd::Client};
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
......@@ -56,21 +56,21 @@ pub struct EtcdBucket {
impl KeyValueBucket for EtcdBucket {
async fn insert(
&self,
key: String,
value: String,
key: &Key,
value: &str,
// "version" in etcd speak. revision is a global cluster-wide value
revision: u64,
) -> Result<StorageOutcome, StorageError> {
let version = revision;
if version == 0 {
self.create(&key, &value).await
self.create(key, value).await
} else {
self.update(&key, &value, version).await
self.update(key, value, version).await
}
}
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
let k = make_key(&self.bucket_name, key);
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
let k = format!("{}/{key}", self.bucket_name);
tracing::trace!("etcd get: {k}");
let mut kvs = self
......@@ -85,10 +85,10 @@ impl KeyValueBucket for EtcdBucket {
Ok(Some(val.into()))
}
async fn delete(&self, key: &str) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
let _ = self
.client
.kv_delete(key, None)
.kv_delete(key.0.clone(), None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
Ok(())
......@@ -98,7 +98,7 @@ impl KeyValueBucket for EtcdBucket {
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
{
let k = make_key(&self.bucket_name, "");
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {k}");
let (_watcher, mut watch_stream) = self
.client
......@@ -121,7 +121,7 @@ impl KeyValueBucket for EtcdBucket {
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
let k = make_key(&self.bucket_name, "");
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}");
let resp = self
......@@ -142,7 +142,7 @@ impl KeyValueBucket for EtcdBucket {
}
impl EtcdBucket {
async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}");
......@@ -187,7 +187,7 @@ impl EtcdBucket {
async fn update(
&self,
key: &str,
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
......@@ -208,7 +208,7 @@ impl EtcdBucket {
tracing::warn!(
current_version,
attempted_next_version = version,
key,
%key,
"update: Wrong revision"
);
// NATS does a resync_update, overwriting the key anyway and getting the new revision.
......@@ -234,12 +234,8 @@ impl EtcdBucket {
}
}
fn make_key(bucket_name: &str, key: &str) -> String {
[
Slug::slugify(bucket_name).to_string(),
Slug::slugify(key).to_string(),
]
.join("/")
fn make_key(bucket_name: &str, key: &Key) -> String {
[Slug::slugify(bucket_name).to_string(), key.to_string()].join("/")
}
#[cfg(feature = "integration")]
......@@ -278,7 +274,7 @@ mod concurrent_create_tests {
let barrier = Arc::new(Barrier::new(num_workers));
// Shared test data
let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4());
let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
let test_value = "test_value";
// Spawn multiple tasks that will all try to create the same key simultaneously
......@@ -302,7 +298,7 @@ mod concurrent_create_tests {
let result = bucket_clone
.lock()
.await
.insert(key_clone, value_clone, 0)
.insert(&key_clone, &value_clone, 0)
.await;
match result {
......
......@@ -11,6 +11,8 @@ use async_trait::async_trait;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
#[derive(Clone)]
......@@ -100,8 +102,8 @@ impl KeyValueStore for MemoryStorage {
impl KeyValueBucket for MemoryBucketRef {
async fn insert(
&self,
key: String,
value: String,
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
let mut locked_data = self.inner.data.lock().await;
......@@ -111,8 +113,11 @@ impl KeyValueBucket for MemoryBucketRef {
};
let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => {
e.insert((revision, value.clone()));
let _ = self.inner.change_sender.send((key, value));
e.insert((revision, value.to_string()));
let _ = self
.inner
.change_sender
.send((key.to_string(), value.to_string()));
StorageOutcome::Created(revision)
}
Entry::Occupied(mut entry) => {
......@@ -120,7 +125,7 @@ impl KeyValueBucket for MemoryBucketRef {
if *rev == revision {
StorageOutcome::Exists(revision)
} else {
entry.insert((revision, value));
entry.insert((revision, value.to_string()));
StorageOutcome::Created(revision)
}
}
......@@ -128,23 +133,23 @@ impl KeyValueBucket for MemoryBucketRef {
Ok(outcome)
}
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
let locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None);
};
Ok(bucket
.data
.get(key)
.get(&key.0)
.map(|(_, v)| bytes::Bytes::from(v.clone())))
}
async fn delete(&self, key: &str) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
let mut locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StorageError::MissingBucket(self.name.to_string()));
};
bucket.data.remove(key);
bucket.data.remove(&key.0);
Ok(())
}
......
......@@ -3,7 +3,9 @@
use std::{collections::HashMap, pin::Pin, time::Duration};
use crate::{protocols::EndpointId, slug::Slug, transports::nats::Client};
use crate::{
protocols::EndpointId, slug::Slug, storage::key_value_store::Key, transports::nats::Client,
};
use async_trait::async_trait;
use futures::StreamExt;
......@@ -109,8 +111,8 @@ impl NATSStorage {
impl KeyValueBucket for NATSBucket {
async fn insert(
&self,
key: String,
value: String,
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
if revision == 0 {
......@@ -120,14 +122,14 @@ impl KeyValueBucket for NATSBucket {
}
}
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
self.nats_store
.get(key)
.await
.map_err(|e| StorageError::NATSError(e.to_string()))
}
async fn delete(&self, key: &str) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
self.nats_store
.delete(key)
.await
......@@ -179,16 +181,16 @@ impl KeyValueBucket for NATSBucket {
}
impl NATSBucket {
async fn create(&self, key: String, value: String) -> Result<StorageOutcome, StorageError> {
match self.nats_store.create(&key, value.into()).await {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
match self.nats_store.create(&key, value.to_string().into()).await {
Ok(revision) => Ok(StorageOutcome::Created(revision)),
Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
// key exists, get the revision
match self.nats_store.entry(&key).await {
match self.nats_store.entry(key).await {
Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
Ok(None) => {
tracing::error!(
key,
%key,
"Race condition, key deleted between create and fetch. Retry."
);
Err(StorageError::Retry)
......@@ -202,20 +204,20 @@ impl NATSBucket {
async fn update(
&self,
key: String,
value: String,
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
match self
.nats_store
.update(key.clone(), value.clone().into(), revision)
.update(key, value.to_string().into(), revision)
.await
{
Ok(revision) => Ok(StorageOutcome::Created(revision)),
Err(err)
if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
{
tracing::warn!(revision, key, "Update WrongLastRevision, resync");
tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
self.resync_update(key, value).await
}
Err(err) => Err(StorageError::NATSError(err.to_string())),
......@@ -224,18 +226,14 @@ impl NATSBucket {
/// We have the wrong revision for a key. Fetch its entry to get the correct revision,
/// and try the update again.
async fn resync_update(
&self,
key: String,
value: String,
) -> Result<StorageOutcome, StorageError> {
match self.nats_store.entry(&key).await {
async fn resync_update(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
match self.nats_store.entry(key).await {
Ok(Some(entry)) => {
// Re-try the update with new version number
let next_rev = entry.revision + 1;
match self
.nats_store
.update(key.clone(), value.into(), next_rev)
.update(key, value.to_string().into(), next_rev)
.await
{
Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
......@@ -245,11 +243,11 @@ impl NATSBucket {
}
}
Ok(None) => {
tracing::warn!(key, "Entry does not exist during resync, creating.");
tracing::warn!(%key, "Entry does not exist during resync, creating.");
self.create(key, value).await
}
Err(err) => {
tracing::error!(key, %err, "Failed fetching entry during resync");
tracing::error!(%key, %err, "Failed fetching entry during resync");
Err(StorageError::NATSError(err.to_string()))
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment