Unverified Commit 50cdae5f authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Add Key abstraction in our KeyValueStore (#3322)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent cf3ac5b6
...@@ -14,6 +14,7 @@ use dynamo_runtime::{ ...@@ -14,6 +14,7 @@ use dynamo_runtime::{
network::egress::push_router::PushRouter, network::egress::push_router::PushRouter,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
storage::key_value_store::Key,
transports::etcd::{KeyValue, WatchEvent}, transports::etcd::{KeyValue, WatchEvent},
}; };
...@@ -258,7 +259,12 @@ impl ModelWatcher { ...@@ -258,7 +259,12 @@ impl ModelWatcher {
.component(&endpoint_id.component)?; .component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?; let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug(); let model_slug = model_entry.slug();
let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await { let card = match ModelDeploymentCard::load_from_store(
&Key::from_raw(model_slug.to_string()),
&self.drt,
)
.await
{
Ok(Some(mut card)) => { Ok(Some(mut card)) => {
tracing::debug!(card.display_name, "adding model"); tracing::debug!(card.display_name, "adding model");
// Ensure runtime_config is populated // Ensure runtime_config is populated
......
...@@ -239,10 +239,7 @@ async fn run_watcher( ...@@ -239,10 +239,7 @@ async fn run_watcher(
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
let _endpoint_enabler_task = tokio::spawn(async move { let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_update) = rx.recv().await { while let Some(model_update) = rx.recv().await {
// Update HTTP endpoints (existing functionality)
update_http_endpoints(http_service.clone(), model_update.clone()); update_http_endpoints(http_service.clone(), model_update.clone());
// Update metrics (only for added models)
update_model_metrics(model_update, metrics.clone()); update_model_metrics(model_update, metrics.clone());
} }
}); });
......
...@@ -8,6 +8,7 @@ use std::sync::Arc; ...@@ -8,6 +8,7 @@ use std::sync::Arc;
use anyhow::Context as _; use anyhow::Context as _;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug; use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{ use dynamo_runtime::{
component::Endpoint, component::Endpoint,
...@@ -417,13 +418,21 @@ impl LocalModel { ...@@ -417,13 +418,21 @@ impl LocalModel {
let nats_client = endpoint.drt().nats_client(); let nats_client = endpoint.drt().nats_client();
self.card.move_to_nats(nats_client.clone()).await?; self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd // Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string(); let key = self.card.slug().to_string();
// TODO: Next PR will use this
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card) .publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?; .await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card. // Publish our ModelEntry to etcd. This allows ingress to find the model card.
......
...@@ -23,7 +23,9 @@ use crate::model_type::{ModelInput, ModelType}; ...@@ -23,7 +23,9 @@ use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}; use dynamo_runtime::storage::key_value_store::{
EtcdStorage, Key, KeyValueStore, KeyValueStoreManager,
};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
...@@ -394,7 +396,7 @@ impl ModelDeploymentCard { ...@@ -394,7 +396,7 @@ impl ModelDeploymentCard {
/// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use. /// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use.
/// Card should be fully local and ready to use when the call returns. /// Card should be fully local and ready to use when the call returns.
pub async fn load_from_store( pub async fn load_from_store(
model_slug: &Slug, mdc_key: &Key,
drt: &DistributedRuntime, drt: &DistributedRuntime,
) -> anyhow::Result<Option<Self>> { ) -> anyhow::Result<Option<Self>> {
let Some(etcd_client) = drt.etcd_client() else { let Some(etcd_client) = drt.etcd_client() else {
...@@ -404,7 +406,7 @@ impl ModelDeploymentCard { ...@@ -404,7 +406,7 @@ impl ModelDeploymentCard {
let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client)); let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
let card_store = Arc::new(KeyValueStoreManager::new(store)); let card_store = Arc::new(KeyValueStoreManager::new(store));
let Some(mut card) = card_store let Some(mut card) = card_store
.load::<ModelDeploymentCard>(ROOT_PATH, model_slug) .load::<ModelDeploymentCard>(ROOT_PATH, mdc_key)
.await? .await?
else { else {
return Ok(None); return Ok(None);
......
...@@ -225,8 +225,8 @@ impl Component { ...@@ -225,8 +225,8 @@ impl Component {
&self.namespace &self.namespace
} }
pub fn name(&self) -> String { pub fn name(&self) -> &str {
self.name.clone() &self.name
} }
pub fn labels(&self) -> &[(String, String)] { pub fn labels(&self) -> &[(String, String)] {
...@@ -457,7 +457,7 @@ impl Endpoint { ...@@ -457,7 +457,7 @@ impl Endpoint {
pub fn etcd_path(&self) -> EtcdPath { pub fn etcd_path(&self) -> EtcdPath {
EtcdPath::new_endpoint( EtcdPath::new_endpoint(
&self.component.namespace().name(), &self.component.namespace().name(),
&self.component.name(), self.component.name(),
&self.name, &self.name,
) )
.expect("Endpoint name and component name should be valid") .expect("Endpoint name and component name should be valid")
...@@ -465,12 +465,15 @@ impl Endpoint { ...@@ -465,12 +465,15 @@ impl Endpoint {
/// The full path of an instance in etcd /// The full path of an instance in etcd
pub fn etcd_path_with_lease_id(&self, lease_id: i64) -> String { pub fn etcd_path_with_lease_id(&self, lease_id: i64) -> String {
let endpoint_root = self.etcd_root(); format!("{INSTANCE_ROOT_PATH}/{}", self.unique_path(lease_id))
if self.is_static { }
endpoint_root
} else { /// Full path of this endpoint with forward slash separators, including lease id
format!("{endpoint_root}:{lease_id:x}") pub fn unique_path(&self, lease_id: i64) -> String {
} let ns = self.component.namespace().name();
let cp = self.component.name();
let ep = self.name();
format!("{ns}/{cp}/{ep}/{lease_id:x}")
} }
/// The endpoint as an EtcdPath object with lease ID /// The endpoint as an EtcdPath object with lease ID
...@@ -480,7 +483,7 @@ impl Endpoint { ...@@ -480,7 +483,7 @@ impl Endpoint {
} else { } else {
EtcdPath::new_endpoint_with_lease( EtcdPath::new_endpoint_with_lease(
&self.component.namespace().name(), &self.component.namespace().name(),
&self.component.name(), self.component.name(),
&self.name, &self.name,
lease_id, lease_id,
) )
......
...@@ -425,8 +425,8 @@ mod integration_tests { ...@@ -425,8 +425,8 @@ mod integration_tests {
let request_timeout = Duration::from_secs(3); let request_timeout = Duration::from_secs(3);
let config = HealthCheckConfig { let config = HealthCheckConfig {
canary_wait_time: canary_wait_time, canary_wait_time,
request_timeout: request_timeout, request_timeout,
}; };
let manager = HealthCheckManager::new(drt.clone(), config); let manager = HealthCheckManager::new(drt.clone(), config);
......
...@@ -23,6 +23,45 @@ pub use nats::NATSStorage; ...@@ -23,6 +23,45 @@ pub use nats::NATSStorage;
mod etcd; mod etcd;
pub use etcd::EtcdStorage; pub use etcd::EtcdStorage;
/// A key that is safe to use directly in the KV store.
///
/// Build one with [`Key::new`] to have the input passed through `Slug::slugify`, or with
/// [`Key::from_raw`] when the string is already in its final form.
///
/// `Eq` and `Hash` are derived so a `Key` can be used as a `HashMap` / `HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);

impl Key {
    /// Create a key from arbitrary text by running it through `Slug::slugify`.
    pub fn new(s: &str) -> Key {
        Key(Slug::slugify(s).to_string())
    }

    /// Create a Key without changing the string, it is assumed already KV store safe.
    ///
    /// Nothing is checked or rewritten here: the string is stored exactly as given.
    pub fn from_raw(s: String) -> Key {
        Key(s)
    }
}
impl From<&str> for Key {
fn from(s: &str) -> Key {
Key::new(s)
}
}
/// Displays the key as its underlying string, with nothing added around it.
impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// Borrows the key's underlying string.
impl AsRef<str> for Key {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Produces an owned copy of the key's underlying string; the key itself is left intact.
impl From<&Key> for String {
    fn from(k: &Key) -> String {
        k.0.as_str().to_owned()
    }
}
#[async_trait] #[async_trait]
pub trait KeyValueStore: Send + Sync { pub trait KeyValueStore: Send + Sync {
async fn get_or_create_bucket( async fn get_or_create_bucket(
...@@ -48,13 +87,13 @@ impl KeyValueStoreManager { ...@@ -48,13 +87,13 @@ impl KeyValueStoreManager {
pub async fn load<T: for<'a> Deserialize<'a>>( pub async fn load<T: for<'a> Deserialize<'a>>(
&self, &self,
bucket: &str, bucket: &str,
key: &Slug, key: &Key,
) -> Result<Option<T>, StorageError> { ) -> Result<Option<T>, StorageError> {
let Some(bucket) = self.0.get_bucket(bucket).await? else { let Some(bucket) = self.0.get_bucket(bucket).await? else {
// No bucket means no cards // No bucket means no cards
return Ok(None); return Ok(None);
}; };
match bucket.get(key.as_ref()).await { match bucket.get(key).await {
Ok(Some(card_bytes)) => { Ok(Some(card_bytes)) => {
let card: T = serde_json::from_slice(card_bytes.as_ref())?; let card: T = serde_json::from_slice(card_bytes.as_ref())?;
Ok(Some(card)) Ok(Some(card))
...@@ -109,15 +148,13 @@ impl KeyValueStoreManager { ...@@ -109,15 +148,13 @@ impl KeyValueStoreManager {
&self, &self,
bucket_name: &str, bucket_name: &str,
bucket_ttl: Option<Duration>, bucket_ttl: Option<Duration>,
key: &str, key: &Key,
obj: &mut T, obj: &mut T,
) -> anyhow::Result<StorageOutcome> { ) -> anyhow::Result<StorageOutcome> {
let obj_json = serde_json::to_string(obj)?; let obj_json = serde_json::to_string(obj)?;
let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?; let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
let outcome = bucket let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
.insert(key.to_string(), obj_json, obj.revision())
.await?;
match outcome { match outcome {
StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => { StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
...@@ -126,59 +163,6 @@ impl KeyValueStoreManager { ...@@ -126,59 +163,6 @@ impl KeyValueStoreManager {
} }
Ok(outcome) Ok(outcome)
} }
/// Re-publish the model card to the store regularly. Spawns a task and returns.
/// Takes most arguments by value because it will hold on to them in the publish task.
/// Deletes the card on cancellation.
///
/// The spawned task publishes `obj` under `key` in `bucket_name`, then waits for either
/// `publish_interval` to elapse (and publishes again) or `cancel_token` to fire. On
/// cancellation it looks the bucket up, tries to delete `key` from it, and exits.
/// The task is detached: no handle is returned to the caller.
pub fn publish_until_cancelled<T: Serialize + Versioned + Send + Sync + 'static>(
    self: Arc<Self>,
    cancel_token: CancellationToken,
    bucket_name: String,
    bucket_ttl: Option<Duration>,
    publish_interval: Duration,
    key: String,
    mut obj: T,
) {
    tokio::spawn(async move {
        loop {
            // `publish` takes `&self`, so the Arc does not need to be cloned per iteration.
            let publish_result = self
                .publish(&bucket_name, bucket_ttl, &key, &mut obj)
                .await;
            if let Err(err) = publish_result {
                // NOTE(review): the message says the task ends, but the loop keeps going and
                // retries after the next interval. Confirm which behavior is intended.
                tracing::error!(
                    model = key,
                    error = %err,
                    "Failed publishing to KV storage. Ending publish task.",
                );
            }
            tokio::select! {
                _ = tokio::time::sleep(publish_interval) => {},
                _ = cancel_token.cancelled() => {
                    tracing::trace!(model_service_name = key, "Publish loop cancelled");
                    match self.0.get_bucket(&bucket_name).await {
                        Ok(Some(bucket)) => {
                            // Only report the deletion when it actually succeeded.
                            match bucket.delete(&key).await {
                                Ok(()) => {
                                    tracing::trace!(bucket_name, key, "Deleted Model Deployment Card from NATS");
                                }
                                Err(err) => {
                                    // This is usually expected, our NATS connection is closed
                                    tracing::trace!(bucket_name, key, %err, "Error delete published card from NATS on publish stop");
                                }
                            }
                        }
                        Ok(None) => {
                            tracing::trace!(bucket_name, key, "Bucket does not exist");
                        }
                        Err(err) => {
                            tracing::trace!(bucket_name, %err, "publish_until_cancelled shutdown error");
                        }
                    }
                    // Stop publishing
                    break;
                }
            }
        }
    });
}
} }
/// An online storage for key-value config values. /// An online storage for key-value config values.
...@@ -189,16 +173,16 @@ pub trait KeyValueBucket: Send { ...@@ -189,16 +173,16 @@ pub trait KeyValueBucket: Send {
/// Insert a value into a bucket, if it doesn't exist already /// Insert a value into a bucket, if it doesn't exist already
async fn insert( async fn insert(
&self, &self,
key: String, key: &Key,
value: String, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError>; ) -> Result<StorageOutcome, StorageError>;
/// Fetch an item from the key-value storage /// Fetch an item from the key-value storage
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError>; async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError>;
/// Delete an item from the bucket /// Delete an item from the bucket
async fn delete(&self, key: &str) -> Result<(), StorageError>; async fn delete(&self, key: &Key) -> Result<(), StorageError>;
/// A stream of items inserted into the bucket. /// A stream of items inserted into the bucket.
/// Every time the stream is polled it will either return a newly created entry, or block until /// Every time the stream is polled it will either return a newly created entry, or block until
...@@ -311,9 +295,7 @@ mod tests { ...@@ -311,9 +295,7 @@ mod tests {
let s2 = Arc::clone(&s); let s2 = Arc::clone(&s);
let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?; let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
let res = bucket let res = bucket.insert(&"test1".into(), "value1", 0).await?;
.insert("test1".to_string(), "value1".to_string(), 0)
.await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StorageOutcome::Created(0));
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel(); let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
...@@ -341,26 +323,18 @@ mod tests { ...@@ -341,26 +323,18 @@ mod tests {
// wouldn't be testing the watch behavior. // wouldn't be testing the watch behavior.
got_first_rx.await?; got_first_rx.await?;
let res = bucket let res = bucket.insert(&"test2".into(), "value2", 0).await?;
.insert("test2".to_string(), "value2".to_string(), 0)
.await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StorageOutcome::Created(0));
// Repeat a key and revision. Ignored. // Repeat a key and revision. Ignored.
let res = bucket let res = bucket.insert(&"test2".into(), "value2", 0).await?;
.insert("test2".to_string(), "value2".to_string(), 0)
.await?;
assert_eq!(res, StorageOutcome::Exists(0)); assert_eq!(res, StorageOutcome::Exists(0));
// Increment revision // Increment revision
let res = bucket let res = bucket.insert(&"test2".into(), "value2", 1).await?;
.insert("test2".to_string(), "value2".to_string(), 1)
.await?;
assert_eq!(res, StorageOutcome::Created(1)); assert_eq!(res, StorageOutcome::Created(1));
let res = bucket let res = bucket.insert(&"test3".into(), "value3", 0).await?;
.insert("test3".to_string(), "value3".to_string(), 0)
.await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StorageOutcome::Created(0));
// ingress exits once it has received all values // ingress exits once it has received all values
...@@ -377,9 +351,7 @@ mod tests { ...@@ -377,9 +351,7 @@ mod tests {
let bucket: &'static _ = let bucket: &'static _ =
Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?)); Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));
let res = bucket let res = bucket.insert(&"test1".into(), "value1", 0).await?;
.insert("test1".to_string(), "value1".to_string(), 0)
.await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StorageOutcome::Created(0));
let stream = bucket.watch().await?; let stream = bucket.watch().await?;
...@@ -397,9 +369,7 @@ mod tests { ...@@ -397,9 +369,7 @@ mod tests {
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K'])); assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
}); });
bucket bucket.insert(&"test1".into(), "GK", 1).await?;
.insert("test1".to_string(), "GK".to_string(), 1)
.await?;
let _ = futures::join!(handle1, handle2); let _ = futures::join!(handle1, handle2);
Ok(()) Ok(())
......
...@@ -5,7 +5,7 @@ use std::collections::HashMap; ...@@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::pin::Pin; use std::pin::Pin;
use std::time::Duration; use std::time::Duration;
use crate::{slug::Slug, transports::etcd::Client}; use crate::{slug::Slug, storage::key_value_store::Key, transports::etcd::Client};
use async_stream::stream; use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions}; use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
...@@ -56,21 +56,21 @@ pub struct EtcdBucket { ...@@ -56,21 +56,21 @@ pub struct EtcdBucket {
impl KeyValueBucket for EtcdBucket { impl KeyValueBucket for EtcdBucket {
async fn insert( async fn insert(
&self, &self,
key: String, key: &Key,
value: String, value: &str,
// "version" in etcd speak. revision is a global cluster-wide value // "version" in etcd speak. revision is a global cluster-wide value
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
let version = revision; let version = revision;
if version == 0 { if version == 0 {
self.create(&key, &value).await self.create(key, value).await
} else { } else {
self.update(&key, &value, version).await self.update(key, value, version).await
} }
} }
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
let k = make_key(&self.bucket_name, key); let k = format!("{}/{key}", self.bucket_name);
tracing::trace!("etcd get: {k}"); tracing::trace!("etcd get: {k}");
let mut kvs = self let mut kvs = self
...@@ -85,10 +85,10 @@ impl KeyValueBucket for EtcdBucket { ...@@ -85,10 +85,10 @@ impl KeyValueBucket for EtcdBucket {
Ok(Some(val.into())) Ok(Some(val.into()))
} }
async fn delete(&self, key: &str) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StorageError> {
let _ = self let _ = self
.client .client
.kv_delete(key, None) .kv_delete(key.0.clone(), None)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StorageError::EtcdError(e.to_string()))?;
Ok(()) Ok(())
...@@ -98,7 +98,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -98,7 +98,7 @@ impl KeyValueBucket for EtcdBucket {
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError> ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
{ {
let k = make_key(&self.bucket_name, ""); let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {k}"); tracing::trace!("etcd watch: {k}");
let (_watcher, mut watch_stream) = self let (_watcher, mut watch_stream) = self
.client .client
...@@ -121,7 +121,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -121,7 +121,7 @@ impl KeyValueBucket for EtcdBucket {
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
let k = make_key(&self.bucket_name, ""); let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}"); tracing::trace!("etcd entries: {k}");
let resp = self let resp = self
...@@ -142,7 +142,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -142,7 +142,7 @@ impl KeyValueBucket for EtcdBucket {
} }
impl EtcdBucket { impl EtcdBucket {
async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> { async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
let k = make_key(&self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}"); tracing::trace!("etcd create: {k}");
...@@ -187,7 +187,7 @@ impl EtcdBucket { ...@@ -187,7 +187,7 @@ impl EtcdBucket {
async fn update( async fn update(
&self, &self,
key: &str, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
...@@ -208,7 +208,7 @@ impl EtcdBucket { ...@@ -208,7 +208,7 @@ impl EtcdBucket {
tracing::warn!( tracing::warn!(
current_version, current_version,
attempted_next_version = version, attempted_next_version = version,
key, %key,
"update: Wrong revision" "update: Wrong revision"
); );
// NATS does a resync_update, overwriting the key anyway and getting the new revision. // NATS does a resync_update, overwriting the key anyway and getting the new revision.
...@@ -234,12 +234,8 @@ impl EtcdBucket { ...@@ -234,12 +234,8 @@ impl EtcdBucket {
} }
} }
fn make_key(bucket_name: &str, key: &str) -> String { fn make_key(bucket_name: &str, key: &Key) -> String {
[ [Slug::slugify(bucket_name).to_string(), key.to_string()].join("/")
Slug::slugify(bucket_name).to_string(),
Slug::slugify(key).to_string(),
]
.join("/")
} }
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
...@@ -278,7 +274,7 @@ mod concurrent_create_tests { ...@@ -278,7 +274,7 @@ mod concurrent_create_tests {
let barrier = Arc::new(Barrier::new(num_workers)); let barrier = Arc::new(Barrier::new(num_workers));
// Shared test data // Shared test data
let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4()); let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
let test_value = "test_value"; let test_value = "test_value";
// Spawn multiple tasks that will all try to create the same key simultaneously // Spawn multiple tasks that will all try to create the same key simultaneously
...@@ -302,7 +298,7 @@ mod concurrent_create_tests { ...@@ -302,7 +298,7 @@ mod concurrent_create_tests {
let result = bucket_clone let result = bucket_clone
.lock() .lock()
.await .await
.insert(key_clone, value_clone, 0) .insert(&key_clone, &value_clone, 0)
.await; .await;
match result { match result {
......
...@@ -11,6 +11,8 @@ use async_trait::async_trait; ...@@ -11,6 +11,8 @@ use async_trait::async_trait;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
#[derive(Clone)] #[derive(Clone)]
...@@ -100,8 +102,8 @@ impl KeyValueStore for MemoryStorage { ...@@ -100,8 +102,8 @@ impl KeyValueStore for MemoryStorage {
impl KeyValueBucket for MemoryBucketRef { impl KeyValueBucket for MemoryBucketRef {
async fn insert( async fn insert(
&self, &self,
key: String, key: &Key,
value: String, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock().await;
...@@ -111,8 +113,11 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -111,8 +113,11 @@ impl KeyValueBucket for MemoryBucketRef {
}; };
let outcome = match bucket.data.entry(key.to_string()) { let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => { Entry::Vacant(e) => {
e.insert((revision, value.clone())); e.insert((revision, value.to_string()));
let _ = self.inner.change_sender.send((key, value)); let _ = self
.inner
.change_sender
.send((key.to_string(), value.to_string()));
StorageOutcome::Created(revision) StorageOutcome::Created(revision)
} }
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
...@@ -120,7 +125,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -120,7 +125,7 @@ impl KeyValueBucket for MemoryBucketRef {
if *rev == revision { if *rev == revision {
StorageOutcome::Exists(revision) StorageOutcome::Exists(revision)
} else { } else {
entry.insert((revision, value)); entry.insert((revision, value.to_string()));
StorageOutcome::Created(revision) StorageOutcome::Created(revision)
} }
} }
...@@ -128,23 +133,23 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -128,23 +133,23 @@ impl KeyValueBucket for MemoryBucketRef {
Ok(outcome) Ok(outcome)
} }
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get(&self.name) else { let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None); return Ok(None);
}; };
Ok(bucket Ok(bucket
.data .data
.get(key) .get(&key.0)
.map(|(_, v)| bytes::Bytes::from(v.clone()))) .map(|(_, v)| bytes::Bytes::from(v.clone())))
} }
async fn delete(&self, key: &str) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StorageError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get_mut(&self.name) else { let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StorageError::MissingBucket(self.name.to_string())); return Err(StorageError::MissingBucket(self.name.to_string()));
}; };
bucket.data.remove(key); bucket.data.remove(&key.0);
Ok(()) Ok(())
} }
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
use std::{collections::HashMap, pin::Pin, time::Duration}; use std::{collections::HashMap, pin::Pin, time::Duration};
use crate::{protocols::EndpointId, slug::Slug, transports::nats::Client}; use crate::{
protocols::EndpointId, slug::Slug, storage::key_value_store::Key, transports::nats::Client,
};
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
...@@ -109,8 +111,8 @@ impl NATSStorage { ...@@ -109,8 +111,8 @@ impl NATSStorage {
impl KeyValueBucket for NATSBucket { impl KeyValueBucket for NATSBucket {
async fn insert( async fn insert(
&self, &self,
key: String, key: &Key,
value: String, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
if revision == 0 { if revision == 0 {
...@@ -120,14 +122,14 @@ impl KeyValueBucket for NATSBucket { ...@@ -120,14 +122,14 @@ impl KeyValueBucket for NATSBucket {
} }
} }
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
self.nats_store self.nats_store
.get(key) .get(key)
.await .await
.map_err(|e| StorageError::NATSError(e.to_string())) .map_err(|e| StorageError::NATSError(e.to_string()))
} }
async fn delete(&self, key: &str) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StorageError> {
self.nats_store self.nats_store
.delete(key) .delete(key)
.await .await
...@@ -179,16 +181,16 @@ impl KeyValueBucket for NATSBucket { ...@@ -179,16 +181,16 @@ impl KeyValueBucket for NATSBucket {
} }
impl NATSBucket { impl NATSBucket {
async fn create(&self, key: String, value: String) -> Result<StorageOutcome, StorageError> { async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
match self.nats_store.create(&key, value.into()).await { match self.nats_store.create(&key, value.to_string().into()).await {
Ok(revision) => Ok(StorageOutcome::Created(revision)), Ok(revision) => Ok(StorageOutcome::Created(revision)),
Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => { Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
// key exists, get the revision // key exists, get the revision
match self.nats_store.entry(&key).await { match self.nats_store.entry(key).await {
Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)), Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
Ok(None) => { Ok(None) => {
tracing::error!( tracing::error!(
key, %key,
"Race condition, key deleted between create and fetch. Retry." "Race condition, key deleted between create and fetch. Retry."
); );
Err(StorageError::Retry) Err(StorageError::Retry)
...@@ -202,20 +204,20 @@ impl NATSBucket { ...@@ -202,20 +204,20 @@ impl NATSBucket {
async fn update( async fn update(
&self, &self,
key: String, key: &Key,
value: String, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
match self match self
.nats_store .nats_store
.update(key.clone(), value.clone().into(), revision) .update(key, value.to_string().into(), revision)
.await .await
{ {
Ok(revision) => Ok(StorageOutcome::Created(revision)), Ok(revision) => Ok(StorageOutcome::Created(revision)),
Err(err) Err(err)
if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision => if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
{ {
tracing::warn!(revision, key, "Update WrongLastRevision, resync"); tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
self.resync_update(key, value).await self.resync_update(key, value).await
} }
Err(err) => Err(StorageError::NATSError(err.to_string())), Err(err) => Err(StorageError::NATSError(err.to_string())),
...@@ -224,18 +226,14 @@ impl NATSBucket { ...@@ -224,18 +226,14 @@ impl NATSBucket {
/// We have the wrong revision for a key. Fetch its entry to get the correct revision, /// We have the wrong revision for a key. Fetch its entry to get the correct revision,
/// and try the update again. /// and try the update again.
async fn resync_update( async fn resync_update(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
&self, match self.nats_store.entry(key).await {
key: String,
value: String,
) -> Result<StorageOutcome, StorageError> {
match self.nats_store.entry(&key).await {
Ok(Some(entry)) => { Ok(Some(entry)) => {
// Re-try the update with new version number // Re-try the update with new version number
let next_rev = entry.revision + 1; let next_rev = entry.revision + 1;
match self match self
.nats_store .nats_store
.update(key.clone(), value.into(), next_rev) .update(key, value.to_string().into(), next_rev)
.await .await
{ {
Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)), Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
...@@ -245,11 +243,11 @@ impl NATSBucket { ...@@ -245,11 +243,11 @@ impl NATSBucket {
} }
} }
Ok(None) => { Ok(None) => {
tracing::warn!(key, "Entry does not exist during resync, creating."); tracing::warn!(%key, "Entry does not exist during resync, creating.");
self.create(key, value).await self.create(key, value).await
} }
Err(err) => { Err(err) => {
tracing::error!(key, %err, "Failed fetching entry during resync"); tracing::error!(%key, %err, "Failed fetching entry during resync");
Err(StorageError::NATSError(err.to_string())) Err(StorageError::NATSError(err.to_string()))
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment