Unverified Commit 7a7d397c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Introduce storage_client in DistributedRuntime (#3507)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 0a2a820b
...@@ -57,11 +57,11 @@ impl State { ...@@ -57,11 +57,11 @@ impl State {
} }
} }
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self { pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self { Self {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client, etcd_client: Some(etcd_client),
} }
} }
...@@ -155,7 +155,10 @@ impl KserveServiceConfigBuilder { ...@@ -155,7 +155,10 @@ impl KserveServiceConfigBuilder {
let config: KserveServiceConfig = self.build_internal()?; let config: KserveServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client)); let state = match config.etcd_client {
Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
// enable prometheus metrics // enable prometheus metrics
let registry = metrics::Registry::new(); let registry = metrics::Registry::new();
......
...@@ -52,16 +52,12 @@ async fn live_handler( ...@@ -52,16 +52,12 @@ async fn live_handler(
async fn health_handler( async fn health_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>, axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let instances = if let Some(etcd_client) = state.etcd_client() { let instances = match list_all_instances(state.store()).await {
match list_all_instances(etcd_client).await { Ok(instances) => instances,
Ok(instances) => instances, Err(err) => {
Err(err) => { tracing::warn!(%err, "Failed to fetch instances from store");
tracing::warn!("Failed to fetch instances from etcd: {}", err); vec![]
vec![]
}
} }
} else {
vec![]
}; };
let mut endpoints: Vec<String> = instances let mut endpoints: Vec<String> = instances
......
...@@ -19,6 +19,9 @@ use anyhow::Result; ...@@ -19,6 +19,9 @@ use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig; use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::logging::make_request_span; use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::storage::key_value_store::EtcdStore;
use dynamo_runtime::storage::key_value_store::KeyValueStore;
use dynamo_runtime::storage::key_value_store::MemoryStore;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
...@@ -26,11 +29,11 @@ use tokio_util::sync::CancellationToken; ...@@ -26,11 +29,11 @@ use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
/// HTTP service shared state /// HTTP service shared state
#[derive(Default)]
pub struct State { pub struct State {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
etcd_client: Option<etcd::Client>, etcd_client: Option<etcd::Client>,
store: Arc<dyn KeyValueStore>,
flags: StateFlags, flags: StateFlags,
} }
...@@ -76,6 +79,7 @@ impl State { ...@@ -76,6 +79,7 @@ impl State {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client: None, etcd_client: None,
store: Arc::new(MemoryStore::new()),
flags: StateFlags { flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false), chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false),
...@@ -85,11 +89,12 @@ impl State { ...@@ -85,11 +89,12 @@ impl State {
} }
} }
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self { pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self { Self {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client, store: Arc::new(EtcdStore::new(etcd_client.clone())),
etcd_client: Some(etcd_client),
flags: StateFlags { flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false), chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false),
...@@ -115,6 +120,10 @@ impl State { ...@@ -115,6 +120,10 @@ impl State {
self.etcd_client.as_ref() self.etcd_client.as_ref()
} }
pub fn store(&self) -> Arc<dyn KeyValueStore> {
self.store.clone()
}
// TODO // TODO
pub fn sse_keep_alive(&self) -> Option<Duration> { pub fn sse_keep_alive(&self) -> Option<Duration> {
None None
...@@ -294,9 +303,10 @@ impl HttpServiceConfigBuilder { ...@@ -294,9 +303,10 @@ impl HttpServiceConfigBuilder {
let config: HttpServiceConfig = self.build_internal()?; let config: HttpServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let etcd_client = config.etcd_client; let state = match config.etcd_client {
let state = Arc::new(State::new_with_etcd(model_manager, etcd_client)); Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
state state
.flags .flags
.set(&EndpointType::Chat, config.enable_chat_endpoints); .set(&EndpointType::Chat, config.enable_chat_endpoints);
......
...@@ -12,7 +12,7 @@ use dynamo_runtime::storage::key_value_store::Key; ...@@ -12,7 +12,7 @@ use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{ use dynamo_runtime::{
component::Endpoint, component::Endpoint,
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}, storage::key_value_store::{EtcdStore, KeyValueStore, KeyValueStoreManager},
}; };
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
...@@ -409,7 +409,7 @@ impl LocalModel { ...@@ -409,7 +409,7 @@ impl LocalModel {
self.card.move_to_nats(nats_client.clone()).await?; self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to KV store // Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0); let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
let key = Key::from_raw(endpoint.unique_path(lease_id)); let key = Key::from_raw(endpoint.unique_path(lease_id));
......
...@@ -23,7 +23,7 @@ use anyhow::{Context, Result}; ...@@ -23,7 +23,7 @@ use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{ use dynamo_runtime::storage::key_value_store::{
EtcdStorage, Key, KeyValueStore, KeyValueStoreManager, EtcdStore, Key, KeyValueStore, KeyValueStoreManager,
}; };
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -457,7 +457,7 @@ impl ModelDeploymentCard { ...@@ -457,7 +457,7 @@ impl ModelDeploymentCard {
// Should be impossible because we only get here on an etcd event // Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client"); anyhow::bail!("Missing etcd_client");
}; };
let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client)); let store: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client));
let card_store = Arc::new(KeyValueStoreManager::new(store)); let card_store = Arc::new(KeyValueStoreManager::new(store));
let Some(mut card) = card_store let Some(mut card) = card_store
.load::<ModelDeploymentCard>(ROOT_PATH, mdc_key) .load::<ModelDeploymentCard>(ROOT_PATH, mdc_key)
......
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
//! //!
//! TODO: Top-level Overview of Endpoints/Functions //! TODO: Top-level Overview of Endpoints/Functions
use std::fmt;
use crate::{ use crate::{
config::HealthStatus, config::HealthStatus,
discovery::Lease, discovery::Lease,
...@@ -70,7 +72,7 @@ pub mod service; ...@@ -70,7 +72,7 @@ pub mod service;
pub use client::{Client, InstanceSource}; pub use client::{Client, InstanceSource};
/// The root etcd path where each instance registers itself in etcd. /// The root key-value path where each instance registers itself in.
/// An instance is namespace+component+endpoint+lease_id and must be unique. /// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "v1/instances"; pub const INSTANCE_ROOT_PATH: &str = "v1/instances";
...@@ -91,7 +93,7 @@ pub struct Registry { ...@@ -91,7 +93,7 @@ pub struct Registry {
inner: Arc<tokio::sync::Mutex<RegistryInner>>, inner: Arc<tokio::sync::Mutex<RegistryInner>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Instance { pub struct Instance {
pub component: String, pub component: String,
pub endpoint: String, pub endpoint: String,
...@@ -113,6 +115,30 @@ impl Instance { ...@@ -113,6 +115,30 @@ impl Instance {
} }
} }
impl fmt::Display for Instance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}/{}/{}/{}",
self.namespace, self.component, self.endpoint, self.instance_id
)
}
}
/// Sort by string name
impl std::cmp::Ord for Instance {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.to_string().cmp(&other.to_string())
}
}
impl PartialOrd for Instance {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// Since Ord is fully implemented, the comparison is always total.
Some(self.cmp(other))
}
}
/// A [Component] a discoverable entity in the distributed runtime. /// A [Component] a discoverable entity in the distributed runtime.
/// You can host [Endpoint] on a [Component] by first creating /// You can host [Endpoint] on a [Component] by first creating
/// a [Service] then adding one or more [Endpoint] to the [Service]. /// a [Service] then adding one or more [Endpoint] to the [Service].
...@@ -197,8 +223,8 @@ impl MetricsRegistry for Component { ...@@ -197,8 +223,8 @@ impl MetricsRegistry for Component {
} }
impl Component { impl Component {
/// The component part of an instance path in etcd. /// The component part of an instance path in key-value store.
pub fn etcd_root(&self) -> String { pub fn instance_root(&self) -> String {
let ns = self.namespace.name(); let ns = self.namespace.name();
let cp = &self.name; let cp = &self.name;
format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}") format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}")
...@@ -240,27 +266,23 @@ impl Component { ...@@ -240,27 +266,23 @@ impl Component {
} }
pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> { pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let Some(etcd_client) = self.drt.etcd_client() else { let client = self.drt.store();
let Some(bucket) = client.get_bucket(&self.instance_root()).await? else {
return Ok(vec![]); return Ok(vec![]);
}; };
let mut out = vec![]; let entries = bucket.entries().await?;
// The extra slash is important to only list exact component matches, not substrings. let mut instances = Vec::with_capacity(entries.len());
for kv in etcd_client for (name, bytes) in entries.into_iter() {
.kv_get_prefix(format!("{}/", self.etcd_root())) let val = match serde_json::from_slice::<Instance>(&bytes) {
.await?
{
let val = match serde_json::from_slice::<Instance>(kv.value()) {
Ok(val) => val, Ok(val) => val,
Err(err) => { Err(err) => {
anyhow::bail!( anyhow::bail!("Error converting storage response to Instance: {err}. {name}",);
"Error converting etcd response to Instance: {err}. {}",
kv.value_str()?
);
} }
}; };
out.push(val); instances.push(val);
} }
Ok(out) instances.sort();
Ok(instances)
} }
/// Scrape ServiceSet, which contains NATS stats as well as user defined stats /// Scrape ServiceSet, which contains NATS stats as well as user defined stats
...@@ -445,7 +467,7 @@ impl Endpoint { ...@@ -445,7 +467,7 @@ impl Endpoint {
/// The endpoint part of an instance path in etcd /// The endpoint part of an instance path in etcd
pub fn etcd_root(&self) -> String { pub fn etcd_root(&self) -> String {
let component_path = self.component.etcd_root(); let component_path = self.component.instance_root();
let endpoint_name = &self.name; let endpoint_name = &self.name;
format!("{component_path}/{endpoint_name}") format!("{component_path}/{endpoint_name}")
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub use crate::component::Component; pub use crate::component::Component;
use crate::storage::key_value_store::{EtcdStore, KeyValueStore, MemoryStore};
use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{ use crate::{
ErrorContext, RuntimeCallback, ErrorContext, RuntimeCallback,
...@@ -44,10 +45,14 @@ impl DistributedRuntime { ...@@ -44,10 +45,14 @@ impl DistributedRuntime {
let runtime_clone = runtime.clone(); let runtime_clone = runtime.clone();
let etcd_client = if is_static { let (etcd_client, store) = if is_static {
None let store: Arc<dyn KeyValueStore> = Arc::new(MemoryStore::new());
(None, store)
} else { } else {
Some(etcd::Client::new(etcd_config.clone(), runtime_clone).await?) let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?;
let store: Arc<dyn KeyValueStore> = Arc::new(EtcdStore::new(etcd_client.clone()));
(Some(etcd_client), store)
}; };
let nats_client = nats_config.clone().connect().await?; let nats_client = nats_config.clone().connect().await?;
...@@ -77,6 +82,7 @@ impl DistributedRuntime { ...@@ -77,6 +82,7 @@ impl DistributedRuntime {
let distributed_runtime = Self { let distributed_runtime = Self {
runtime, runtime,
etcd_client, etcd_client,
store,
nats_client, nats_client,
tcp_server: Arc::new(OnceCell::new()), tcp_server: Arc::new(OnceCell::new()),
system_status_server: Arc::new(OnceLock::new()), system_status_server: Arc::new(OnceLock::new()),
...@@ -270,6 +276,12 @@ impl DistributedRuntime { ...@@ -270,6 +276,12 @@ impl DistributedRuntime {
self.etcd_client.clone() self.etcd_client.clone()
} }
/// An interface to store things. Will eventually replace `etcd_client`.
/// Currently does key-value, but will grow to include whatever we need to store.
pub fn store(&self) -> Arc<dyn KeyValueStore> {
self.store.clone()
}
pub fn child_token(&self) -> CancellationToken { pub fn child_token(&self) -> CancellationToken {
self.runtime.child_token() self.runtime.child_token()
} }
......
...@@ -7,28 +7,28 @@ ...@@ -7,28 +7,28 @@
//! the entire distributed system, complementing the component-specific //! the entire distributed system, complementing the component-specific
//! instance listing in `component.rs`. //! instance listing in `component.rs`.
use std::sync::Arc;
use crate::component::{INSTANCE_ROOT_PATH, Instance}; use crate::component::{INSTANCE_ROOT_PATH, Instance};
use crate::storage::key_value_store::KeyValueStore;
use crate::transports::etcd::Client as EtcdClient; use crate::transports::etcd::Client as EtcdClient;
pub async fn list_all_instances(etcd_client: &EtcdClient) -> anyhow::Result<Vec<Instance>> { pub async fn list_all_instances(client: Arc<dyn KeyValueStore>) -> anyhow::Result<Vec<Instance>> {
let mut instances = Vec::new(); let Some(bucket) = client.get_bucket(INSTANCE_ROOT_PATH).await? else {
return Ok(vec![]);
};
for kv in etcd_client let entries = bucket.entries().await?;
.kv_get_prefix(format!("{}/", INSTANCE_ROOT_PATH)) let mut instances = Vec::with_capacity(entries.len());
.await? for (name, bytes) in entries.into_iter() {
{ match serde_json::from_slice::<Instance>(&bytes) {
match serde_json::from_slice::<Instance>(kv.value()) {
Ok(instance) => instances.push(instance), Ok(instance) => instances.push(instance),
Err(err) => { Err(err) => {
tracing::warn!( tracing::warn!(%err, key = name, "Failed to parse instance from storage");
"Failed to parse instance from etcd: {}. Key: {}, Value: {}",
err,
kv.key_str().unwrap_or("invalid_key"),
kv.value_str().unwrap_or("invalid_value")
);
} }
} }
} }
instances.sort();
Ok(instances) Ok(instances)
} }
...@@ -51,7 +51,9 @@ pub use system_health::{HealthCheckTarget, SystemHealth}; ...@@ -51,7 +51,9 @@ pub use system_health::{HealthCheckTarget, SystemHealth};
pub use tokio_util::sync::CancellationToken; pub use tokio_util::sync::CancellationToken;
pub use worker::Worker; pub use worker::Worker;
use crate::metrics::prometheus_names::distributed_runtime; use crate::{
metrics::prometheus_names::distributed_runtime, storage::key_value_store::KeyValueStore,
};
use component::{Endpoint, InstanceSource}; use component::{Endpoint, InstanceSource};
use utils::GracefulShutdownTracker; use utils::GracefulShutdownTracker;
...@@ -152,6 +154,7 @@ pub struct DistributedRuntime { ...@@ -152,6 +154,7 @@ pub struct DistributedRuntime {
// we might consider a unifed transport manager here // we might consider a unifed transport manager here
etcd_client: Option<transports::etcd::Client>, etcd_client: Option<transports::etcd::Client>,
nats_client: transports::nats::Client, nats_client: transports::nats::Client,
store: Arc<dyn KeyValueStore>,
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>, tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>, system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
......
...@@ -17,11 +17,11 @@ use futures::StreamExt; ...@@ -17,11 +17,11 @@ use futures::StreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
mod mem; mod mem;
pub use mem::MemoryStorage; pub use mem::MemoryStore;
mod nats; mod nats;
pub use nats::NATSStorage; pub use nats::NATSStore;
mod etcd; mod etcd;
pub use etcd::EtcdStorage; pub use etcd::EtcdStore;
/// A key that is safe to use directly in the KV store. /// A key that is safe to use directly in the KV store.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
...@@ -69,12 +69,14 @@ pub trait KeyValueStore: Send + Sync { ...@@ -69,12 +69,14 @@ pub trait KeyValueStore: Send + Sync {
bucket_name: &str, bucket_name: &str,
// auto-delete items older than this // auto-delete items older than this
ttl: Option<Duration>, ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError>; ) -> Result<Box<dyn KeyValueBucket>, StoreError>;
async fn get_bucket( async fn get_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError>; ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError>;
fn connection_id(&self) -> u64;
} }
pub struct KeyValueStoreManager(Box<dyn KeyValueStore>); pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
...@@ -88,7 +90,7 @@ impl KeyValueStoreManager { ...@@ -88,7 +90,7 @@ impl KeyValueStoreManager {
&self, &self,
bucket: &str, bucket: &str,
key: &Key, key: &Key,
) -> Result<Option<T>, StorageError> { ) -> Result<Option<T>, StoreError> {
let Some(bucket) = self.0.get_bucket(bucket).await? else { let Some(bucket) = self.0.get_bucket(bucket).await? else {
// No bucket means no cards // No bucket means no cards
return Ok(None); return Ok(None);
...@@ -101,7 +103,7 @@ impl KeyValueStoreManager { ...@@ -101,7 +103,7 @@ impl KeyValueStoreManager {
Ok(None) => Ok(None), Ok(None) => Ok(None),
Err(err) => { Err(err) => {
// TODO look at what errors NATS can give us and make more specific wrappers // TODO look at what errors NATS can give us and make more specific wrappers
Err(StorageError::NATSError(err.to_string())) Err(StoreError::NATSError(err.to_string()))
} }
} }
} }
...@@ -114,7 +116,7 @@ impl KeyValueStoreManager { ...@@ -114,7 +116,7 @@ impl KeyValueStoreManager {
bucket_name: &str, bucket_name: &str,
bucket_ttl: Option<Duration>, bucket_ttl: Option<Duration>,
) -> ( ) -> (
tokio::task::JoinHandle<Result<(), StorageError>>, tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<T>, tokio::sync::mpsc::UnboundedReceiver<T>,
) { ) {
let bucket_name = bucket_name.to_string(); let bucket_name = bucket_name.to_string();
...@@ -139,7 +141,7 @@ impl KeyValueStoreManager { ...@@ -139,7 +141,7 @@ impl KeyValueStoreManager {
let _ = tx.send(card); let _ = tx.send(card);
} }
Ok::<(), StorageError>(()) Ok::<(), StoreError>(())
}); });
(watch_task, rx) (watch_task, rx)
} }
...@@ -150,14 +152,14 @@ impl KeyValueStoreManager { ...@@ -150,14 +152,14 @@ impl KeyValueStoreManager {
bucket_ttl: Option<Duration>, bucket_ttl: Option<Duration>,
key: &Key, key: &Key,
obj: &mut T, obj: &mut T,
) -> anyhow::Result<StorageOutcome> { ) -> anyhow::Result<StoreOutcome> {
let obj_json = serde_json::to_string(obj)?; let obj_json = serde_json::to_string(obj)?;
let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?; let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
let outcome = bucket.insert(key, &obj_json, obj.revision()).await?; let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
match outcome { match outcome {
StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => { StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
obj.set_revision(revision); obj.set_revision(revision);
} }
} }
...@@ -176,43 +178,43 @@ pub trait KeyValueBucket: Send { ...@@ -176,43 +178,43 @@ pub trait KeyValueBucket: Send {
key: &Key, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError>; ) -> Result<StoreOutcome, StoreError>;
/// Fetch an item from the key-value storage /// Fetch an item from the key-value storage
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError>; async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
/// Delete an item from the bucket /// Delete an item from the bucket
async fn delete(&self, key: &Key) -> Result<(), StorageError>; async fn delete(&self, key: &Key) -> Result<(), StoreError>;
/// A stream of items inserted into the bucket. /// A stream of items inserted into the bucket.
/// Every time the stream is polled it will either return a newly created entry, or block until /// Every time the stream is polled it will either return a newly created entry, or block until
/// such time. /// such time.
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>; ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError>; async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
} }
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StorageOutcome { pub enum StoreOutcome {
/// The operation succeeded and created a new entry with this revision. /// The operation succeeded and created a new entry with this revision.
/// Note that "create" also means update, because each new revision is a "create". /// Note that "create" also means update, because each new revision is a "create".
Created(u64), Created(u64),
/// The operation did not do anything, the value was already present, with this revision. /// The operation did not do anything, the value was already present, with this revision.
Exists(u64), Exists(u64),
} }
impl fmt::Display for StorageOutcome { impl fmt::Display for StoreOutcome {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
StorageOutcome::Created(revision) => write!(f, "Created at {revision}"), StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
StorageOutcome::Exists(revision) => write!(f, "Exists at {revision}"), StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
} }
} }
} }
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum StorageError { pub enum StoreError {
#[error("Could not find bucket '{0}'")] #[error("Could not find bucket '{0}'")]
MissingBucket(String), MissingBucket(String),
...@@ -291,12 +293,12 @@ mod tests { ...@@ -291,12 +293,12 @@ mod tests {
async fn test_memory_storage() -> anyhow::Result<()> { async fn test_memory_storage() -> anyhow::Result<()> {
init(); init();
let s = Arc::new(MemoryStorage::new()); let s = Arc::new(MemoryStore::new());
let s2 = Arc::clone(&s); let s2 = Arc::clone(&s);
let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?; let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
let res = bucket.insert(&"test1".into(), "value1", 0).await?; let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StoreOutcome::Created(0));
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel(); let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
let ingress = tokio::spawn(async move { let ingress = tokio::spawn(async move {
...@@ -315,27 +317,27 @@ mod tests { ...@@ -315,27 +317,27 @@ mod tests {
let v = stream.next().await.unwrap(); let v = stream.next().await.unwrap();
assert_eq!(v, "value3".as_bytes()); assert_eq!(v, "value3".as_bytes());
Ok::<_, StorageError>(()) Ok::<_, StoreError>(())
}); });
// MemoryStorage uses a HashMap with no inherent ordering, so we must ensure test1 is // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
// fetched before test2 is inserted, otherwise they can come out in any order, and we // fetched before test2 is inserted, otherwise they can come out in any order, and we
// wouldn't be testing the watch behavior. // wouldn't be testing the watch behavior.
got_first_rx.await?; got_first_rx.await?;
let res = bucket.insert(&"test2".into(), "value2", 0).await?; let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StoreOutcome::Created(0));
// Repeat a key and revision. Ignored. // Repeat a key and revision. Ignored.
let res = bucket.insert(&"test2".into(), "value2", 0).await?; let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Exists(0)); assert_eq!(res, StoreOutcome::Exists(0));
// Increment revision // Increment revision
let res = bucket.insert(&"test2".into(), "value2", 1).await?; let res = bucket.insert(&"test2".into(), "value2", 1).await?;
assert_eq!(res, StorageOutcome::Created(1)); assert_eq!(res, StoreOutcome::Created(1));
let res = bucket.insert(&"test3".into(), "value3", 0).await?; let res = bucket.insert(&"test3".into(), "value3", 0).await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StoreOutcome::Created(0));
// ingress exits once it has received all values // ingress exits once it has received all values
let _ = ingress.await?; let _ = ingress.await?;
...@@ -347,12 +349,12 @@ mod tests { ...@@ -347,12 +349,12 @@ mod tests {
async fn test_broadcast_stream() -> anyhow::Result<()> { async fn test_broadcast_stream() -> anyhow::Result<()> {
init(); init();
let s: &'static _ = Box::leak(Box::new(MemoryStorage::new())); let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
let bucket: &'static _ = let bucket: &'static _ =
Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?)); Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));
let res = bucket.insert(&"test1".into(), "value1", 0).await?; let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0)); assert_eq!(res, StoreOutcome::Created(0));
let stream = bucket.watch().await?; let stream = bucket.watch().await?;
let tap = TappableStream::new(stream, 10).await; let tap = TappableStream::new(stream, 10).await;
......
...@@ -10,27 +10,27 @@ use async_stream::stream; ...@@ -10,27 +10,27 @@ use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions}; use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)] #[derive(Clone)]
pub struct EtcdStorage { pub struct EtcdStore {
client: Client, client: Client,
} }
impl EtcdStorage { impl EtcdStore {
pub fn new(client: Client) -> Self { pub fn new(client: Client) -> Self {
Self { client } Self { client }
} }
} }
#[async_trait] #[async_trait]
impl KeyValueStore for EtcdStorage { impl KeyValueStore for EtcdStore {
/// A "bucket" in etcd is a path prefix /// A "bucket" in etcd is a path prefix
async fn get_or_create_bucket( async fn get_or_create_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
_ttl: Option<Duration>, // TODO ttl not used yet _ttl: Option<Duration>, // TODO ttl not used yet
) -> Result<Box<dyn KeyValueBucket>, StorageError> { ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
Ok(self.get_bucket(bucket_name).await?.unwrap()) Ok(self.get_bucket(bucket_name).await?.unwrap())
} }
...@@ -39,12 +39,18 @@ impl KeyValueStore for EtcdStorage { ...@@ -39,12 +39,18 @@ impl KeyValueStore for EtcdStorage {
async fn get_bucket( async fn get_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> { ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
Ok(Some(Box::new(EtcdBucket { Ok(Some(Box::new(EtcdBucket {
client: self.client.clone(), client: self.client.clone(),
bucket_name: bucket_name.to_string(), bucket_name: bucket_name.to_string(),
}))) })))
} }
fn connection_id(&self) -> u64 {
// This conversion from i64 to u64 is safe because etcd lease IDs are u64 internally.
// They present as i64 because of the limitations of the etcd grpc/HTTP JSON API.
self.client.lease_id() as u64
}
} }
pub struct EtcdBucket { pub struct EtcdBucket {
...@@ -60,7 +66,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -60,7 +66,7 @@ impl KeyValueBucket for EtcdBucket {
value: &str, value: &str,
// "version" in etcd speak. revision is a global cluster-wide value // "version" in etcd speak. revision is a global cluster-wide value
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StoreOutcome, StoreError> {
let version = revision; let version = revision;
if version == 0 { if version == 0 {
self.create(key, value).await self.create(key, value).await
...@@ -69,7 +75,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -69,7 +75,7 @@ impl KeyValueBucket for EtcdBucket {
} }
} }
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd get: {k}"); tracing::trace!("etcd get: {k}");
...@@ -77,7 +83,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -77,7 +83,7 @@ impl KeyValueBucket for EtcdBucket {
.client .client
.kv_get(k, None) .kv_get(k, None)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
if kvs.is_empty() { if kvs.is_empty() {
return Ok(None); return Ok(None);
} }
...@@ -85,20 +91,20 @@ impl KeyValueBucket for EtcdBucket { ...@@ -85,20 +91,20 @@ impl KeyValueBucket for EtcdBucket {
Ok(Some(val.into())) Ok(Some(val.into()))
} }
async fn delete(&self, key: &Key) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let k = make_key(&self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd delete: {k}"); tracing::trace!("etcd delete: {k}");
let _ = self let _ = self
.client .client
.kv_delete(k, None) .kv_delete(k, None)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
Ok(()) Ok(())
} }
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError> ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{ {
let k = make_key(&self.bucket_name, &"".into()); let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {k}"); tracing::trace!("etcd watch: {k}");
...@@ -108,7 +114,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -108,7 +114,7 @@ impl KeyValueBucket for EtcdBucket {
.clone() .clone()
.watch(k.as_bytes(), Some(WatchOptions::new().with_prefix())) .watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
let output = stream! { let output = stream! {
while let Ok(Some(resp)) = watch_stream.message().await { while let Ok(Some(resp)) = watch_stream.message().await {
for e in resp.events() { for e in resp.events() {
...@@ -122,7 +128,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -122,7 +128,7 @@ impl KeyValueBucket for EtcdBucket {
Ok(Box::pin(output)) Ok(Box::pin(output))
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, &"".into()); let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}"); tracing::trace!("etcd entries: {k}");
...@@ -130,7 +136,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -130,7 +136,7 @@ impl KeyValueBucket for EtcdBucket {
.client .client
.kv_get_prefix(k) .kv_get_prefix(k)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
let out: HashMap<String, bytes::Bytes> = resp let out: HashMap<String, bytes::Bytes> = resp
.into_iter() .into_iter()
.map(|kv| { .map(|kv| {
...@@ -144,7 +150,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -144,7 +150,7 @@ impl KeyValueBucket for EtcdBucket {
} }
impl EtcdBucket { impl EtcdBucket {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> { async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
let k = make_key(&self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}"); tracing::trace!("etcd create: {k}");
...@@ -166,11 +172,11 @@ impl EtcdBucket { ...@@ -166,11 +172,11 @@ impl EtcdBucket {
.kv_client() .kv_client()
.txn(txn) .txn(txn)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
if result.succeeded() { if result.succeeded() {
// Key was created successfully // Key was created successfully
return Ok(StorageOutcome::Created(1)); // version of new key is always 1 return Ok(StoreOutcome::Created(1)); // version of new key is always 1
} }
// Key already existed, get its version // Key already existed, get its version
...@@ -179,10 +185,10 @@ impl EtcdBucket { ...@@ -179,10 +185,10 @@ impl EtcdBucket {
&& let Some(kv) = get_resp.kvs().first() && let Some(kv) = get_resp.kvs().first()
{ {
let version = kv.version() as u64; let version = kv.version() as u64;
return Ok(StorageOutcome::Exists(version)); return Ok(StoreOutcome::Exists(version));
} }
// Shouldn't happen, but handle edge case // Shouldn't happen, but handle edge case
Err(StorageError::EtcdError( Err(StoreError::EtcdError(
"Unexpected transaction response".to_string(), "Unexpected transaction response".to_string(),
)) ))
} }
...@@ -192,7 +198,7 @@ impl EtcdBucket { ...@@ -192,7 +198,7 @@ impl EtcdBucket {
key: &Key, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StoreOutcome, StoreError> {
let version = revision; let version = revision;
let k = make_key(&self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd update: {k}"); tracing::trace!("etcd update: {k}");
...@@ -201,9 +207,9 @@ impl EtcdBucket { ...@@ -201,9 +207,9 @@ impl EtcdBucket {
.client .client
.kv_get(k.clone(), None) .kv_get(k.clone(), None)
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
if kvs.is_empty() { if kvs.is_empty() {
return Err(StorageError::MissingKey(key.to_string())); return Err(StoreError::MissingKey(key.to_string()));
} }
let current_version = kvs.first().unwrap().version() as u64; let current_version = kvs.first().unwrap().version() as u64;
if current_version != version + 1 { if current_version != version + 1 {
...@@ -224,17 +230,17 @@ impl EtcdBucket { ...@@ -224,17 +230,17 @@ impl EtcdBucket {
.client .client
.kv_put_with_options(k, value, Some(put_options)) .kv_put_with_options(k, value, Some(put_options))
.await .await
.map_err(|e| StorageError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
Ok(match put_resp.take_prev_key() { Ok(match put_resp.take_prev_key() {
// Should this be an error? // Should this be an error?
// The key was deleted between our get and put. We re-created it. // The key was deleted between our get and put. We re-created it.
// Version of new key is always 1. // Version of new key is always 1.
// <https://etcd.io/docs/v3.5/learning/data_model/> // <https://etcd.io/docs/v3.5/learning/data_model/>
None => StorageOutcome::Created(1), None => StoreOutcome::Created(1),
// Expected case, success // Expected case, success
Some(kv) if kv.version() as u64 == version + 1 => StorageOutcome::Created(version), Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
// Should this be an error? Something updated the version between our get and put // Should this be an error? Something updated the version between our get and put
Some(kv) => StorageOutcome::Created(kv.version() as u64 + 1), Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
}) })
} }
} }
...@@ -263,9 +269,9 @@ mod concurrent_create_tests { ...@@ -263,9 +269,9 @@ mod concurrent_create_tests {
}); });
} }
async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> { async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
let etcd_client = drt.etcd_client().expect("etcd client should be available"); let etcd_client = drt.etcd_client().expect("etcd client should be available");
let storage = EtcdStorage::new(etcd_client); let storage = EtcdStore::new(etcd_client);
// Create a bucket for testing // Create a bucket for testing
let bucket = Arc::new(tokio::sync::Mutex::new( let bucket = Arc::new(tokio::sync::Mutex::new(
...@@ -307,7 +313,7 @@ mod concurrent_create_tests { ...@@ -307,7 +313,7 @@ mod concurrent_create_tests {
.await; .await;
match result { match result {
Ok(StorageOutcome::Created(version)) => { Ok(StoreOutcome::Created(version)) => {
println!( println!(
"Worker {} successfully created key with version {}", "Worker {} successfully created key with version {}",
worker_id, version worker_id, version
...@@ -316,7 +322,7 @@ mod concurrent_create_tests { ...@@ -316,7 +322,7 @@ mod concurrent_create_tests {
*count += 1; *count += 1;
Ok(version) Ok(version)
} }
Ok(StorageOutcome::Exists(version)) => { Ok(StoreOutcome::Exists(version)) => {
println!( println!(
"Worker {} found key already exists with version {}", "Worker {} found key already exists with version {}",
worker_id, version worker_id, version
......
...@@ -8,25 +8,27 @@ use std::sync::Arc; ...@@ -8,25 +8,27 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use rand::Rng as _;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key; use crate::storage::key_value_store::Key;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)] #[derive(Clone)]
pub struct MemoryStorage { pub struct MemoryStore {
inner: Arc<MemoryStorageInner>, inner: Arc<MemoryStoreInner>,
connection_id: u64,
} }
impl Default for MemoryStorage { impl Default for MemoryStore {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
} }
struct MemoryStorageInner { struct MemoryStoreInner {
data: Mutex<HashMap<String, MemoryBucket>>, data: Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<(String, String)>, change_sender: UnboundedSender<(String, String)>,
change_receiver: Mutex<UnboundedReceiver<(String, String)>>, change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
...@@ -34,7 +36,7 @@ struct MemoryStorageInner { ...@@ -34,7 +36,7 @@ struct MemoryStorageInner {
pub struct MemoryBucketRef { pub struct MemoryBucketRef {
name: String, name: String,
inner: Arc<MemoryStorageInner>, inner: Arc<MemoryStoreInner>,
} }
struct MemoryBucket { struct MemoryBucket {
...@@ -49,27 +51,28 @@ impl MemoryBucket { ...@@ -49,27 +51,28 @@ impl MemoryBucket {
} }
} }
impl MemoryStorage { impl MemoryStore {
pub fn new() -> Self { pub fn new() -> Self {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
MemoryStorage { MemoryStore {
inner: Arc::new(MemoryStorageInner { inner: Arc::new(MemoryStoreInner {
data: Mutex::new(HashMap::new()), data: Mutex::new(HashMap::new()),
change_sender: tx, change_sender: tx,
change_receiver: Mutex::new(rx), change_receiver: Mutex::new(rx),
}), }),
connection_id: rand::rng().random(),
} }
} }
} }
#[async_trait] #[async_trait]
impl KeyValueStore for MemoryStorage { impl KeyValueStore for MemoryStore {
async fn get_or_create_bucket( async fn get_or_create_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
// MemoryStorage doesn't respect TTL yet // MemoryStore doesn't respect TTL yet
_ttl: Option<Duration>, _ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError> { ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock().await;
// Ensure the bucket exists // Ensure the bucket exists
locked_data locked_data
...@@ -82,11 +85,11 @@ impl KeyValueStore for MemoryStorage { ...@@ -82,11 +85,11 @@ impl KeyValueStore for MemoryStorage {
})) }))
} }
/// This operation cannot fail on MemoryStorage. Always returns Ok. /// This operation cannot fail on MemoryStore. Always returns Ok.
async fn get_bucket( async fn get_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> { ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock().await;
match locked_data.get(bucket_name) { match locked_data.get(bucket_name) {
Some(_) => Ok(Some(Box::new(MemoryBucketRef { Some(_) => Ok(Some(Box::new(MemoryBucketRef {
...@@ -96,6 +99,10 @@ impl KeyValueStore for MemoryStorage { ...@@ -96,6 +99,10 @@ impl KeyValueStore for MemoryStorage {
None => Ok(None), None => Ok(None),
} }
} }
fn connection_id(&self) -> u64 {
self.connection_id
}
} }
#[async_trait] #[async_trait]
...@@ -105,11 +112,11 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -105,11 +112,11 @@ impl KeyValueBucket for MemoryBucketRef {
key: &Key, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StoreOutcome, StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock().await;
let mut b = locked_data.get_mut(&self.name); let mut b = locked_data.get_mut(&self.name);
let Some(bucket) = b.as_mut() else { let Some(bucket) = b.as_mut() else {
return Err(StorageError::MissingBucket(self.name.to_string())); return Err(StoreError::MissingBucket(self.name.to_string()));
}; };
let outcome = match bucket.data.entry(key.to_string()) { let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => { Entry::Vacant(e) => {
...@@ -118,22 +125,22 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -118,22 +125,22 @@ impl KeyValueBucket for MemoryBucketRef {
.inner .inner
.change_sender .change_sender
.send((key.to_string(), value.to_string())); .send((key.to_string(), value.to_string()));
StorageOutcome::Created(revision) StoreOutcome::Created(revision)
} }
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
let (rev, _v) = entry.get(); let (rev, _v) = entry.get();
if *rev == revision { if *rev == revision {
StorageOutcome::Exists(revision) StoreOutcome::Exists(revision)
} else { } else {
entry.insert((revision, value.to_string())); entry.insert((revision, value.to_string()));
StorageOutcome::Created(revision) StoreOutcome::Created(revision)
} }
} }
}; };
Ok(outcome) Ok(outcome)
} }
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get(&self.name) else { let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None); return Ok(None);
...@@ -144,10 +151,10 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -144,10 +151,10 @@ impl KeyValueBucket for MemoryBucketRef {
.map(|(_, v)| bytes::Bytes::from(v.clone()))) .map(|(_, v)| bytes::Bytes::from(v.clone())))
} }
async fn delete(&self, key: &Key) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get_mut(&self.name) else { let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StorageError::MissingBucket(self.name.to_string())); return Err(StoreError::MissingBucket(self.name.to_string()));
}; };
bucket.data.remove(&key.0); bucket.data.remove(&key.0);
Ok(()) Ok(())
...@@ -158,7 +165,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -158,7 +165,7 @@ impl KeyValueBucket for MemoryBucketRef {
/// Caller takes the lock so only a single caller may use this at once. /// Caller takes the lock so only a single caller may use this at once.
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError> ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{ {
Ok(Box::pin(async_stream::stream! { Ok(Box::pin(async_stream::stream! {
// All the existing ones first // All the existing ones first
...@@ -192,7 +199,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -192,7 +199,7 @@ impl KeyValueBucket for MemoryBucketRef {
})) }))
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock().await;
match locked_data.get(&self.name) { match locked_data.get(&self.name) {
Some(bucket) => Ok(bucket Some(bucket) => Ok(bucket
...@@ -200,7 +207,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -200,7 +207,7 @@ impl KeyValueBucket for MemoryBucketRef {
.iter() .iter()
.map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone()))) .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
.collect()), .collect()),
None => Err(StorageError::MissingBucket(self.name.clone())), None => Err(StoreError::MissingBucket(self.name.clone())),
} }
} }
} }
...@@ -9,10 +9,10 @@ use crate::{ ...@@ -9,10 +9,10 @@ use crate::{
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)] #[derive(Clone)]
pub struct NATSStorage { pub struct NATSStore {
client: Client, client: Client,
endpoint: EndpointId, endpoint: EndpointId,
} }
...@@ -22,12 +22,12 @@ pub struct NATSBucket { ...@@ -22,12 +22,12 @@ pub struct NATSBucket {
} }
#[async_trait] #[async_trait]
impl KeyValueStore for NATSStorage { impl KeyValueStore for NATSStore {
async fn get_or_create_bucket( async fn get_or_create_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
ttl: Option<Duration>, ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError> { ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
let name = Slug::slugify(bucket_name); let name = Slug::slugify(bucket_name);
let nats_store = self let nats_store = self
.get_or_create_key_value(&self.endpoint.namespace, &name, ttl) .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
...@@ -38,18 +38,22 @@ impl KeyValueStore for NATSStorage { ...@@ -38,18 +38,22 @@ impl KeyValueStore for NATSStorage {
async fn get_bucket( async fn get_bucket(
&self, &self,
bucket_name: &str, bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> { ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
let name = Slug::slugify(bucket_name); let name = Slug::slugify(bucket_name);
match self.get_key_value(&self.endpoint.namespace, &name).await? { match self.get_key_value(&self.endpoint.namespace, &name).await? {
Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))), Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
None => Ok(None), None => Ok(None),
} }
} }
fn connection_id(&self) -> u64 {
self.client.client().server_info().client_id
}
} }
impl NATSStorage { impl NATSStore {
pub fn new(client: Client, endpoint: EndpointId) -> Self { pub fn new(client: Client, endpoint: EndpointId) -> Self {
NATSStorage { client, endpoint } NATSStore { client, endpoint }
} }
/// Get or create a key-value store (aka bucket) in NATS. /// Get or create a key-value store (aka bucket) in NATS.
...@@ -62,7 +66,7 @@ impl NATSStorage { ...@@ -62,7 +66,7 @@ impl NATSStorage {
bucket_name: &Slug, bucket_name: &Slug,
// Delete entries older than this // Delete entries older than this
ttl: Option<Duration>, ttl: Option<Duration>,
) -> Result<async_nats::jetstream::kv::Store, StorageError> { ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await { if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
return Ok(kv); return Ok(kv);
} }
...@@ -82,7 +86,7 @@ impl NATSStorage { ...@@ -82,7 +86,7 @@ impl NATSStorage {
) )
.await; .await;
let nats_store = create_result let nats_store = create_result
.map_err(|err| StorageError::KeyValueError(err.to_string(), bucket_name.clone()))?; .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
tracing::debug!("Created bucket {bucket_name}"); tracing::debug!("Created bucket {bucket_name}");
Ok(nats_store) Ok(nats_store)
} }
...@@ -91,7 +95,7 @@ impl NATSStorage { ...@@ -91,7 +95,7 @@ impl NATSStorage {
&self, &self,
namespace: &str, namespace: &str,
bucket_name: &Slug, bucket_name: &Slug,
) -> Result<Option<async_nats::jetstream::kv::Store>, StorageError> { ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
let bucket_name = single_name(namespace, bucket_name); let bucket_name = single_name(namespace, bucket_name);
let js = self.client.jetstream(); let js = self.client.jetstream();
...@@ -102,7 +106,7 @@ impl NATSStorage { ...@@ -102,7 +106,7 @@ impl NATSStorage {
// bucket doesn't exist // bucket doesn't exist
Ok(None) Ok(None)
} }
Err(err) => Err(StorageError::KeyValueError(err.to_string(), bucket_name)), Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
} }
} }
} }
...@@ -114,7 +118,7 @@ impl KeyValueBucket for NATSBucket { ...@@ -114,7 +118,7 @@ impl KeyValueBucket for NATSBucket {
key: &Key, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StoreOutcome, StoreError> {
if revision == 0 { if revision == 0 {
self.create(key, value).await self.create(key, value).await
} else { } else {
...@@ -122,29 +126,29 @@ impl KeyValueBucket for NATSBucket { ...@@ -122,29 +126,29 @@ impl KeyValueBucket for NATSBucket {
} }
} }
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
self.nats_store self.nats_store
.get(key) .get(key)
.await .await
.map_err(|e| StorageError::NATSError(e.to_string())) .map_err(|e| StoreError::NATSError(e.to_string()))
} }
async fn delete(&self, key: &Key) -> Result<(), StorageError> { async fn delete(&self, key: &Key) -> Result<(), StoreError> {
self.nats_store self.nats_store
.delete(key) .delete(key)
.await .await
.map_err(|e| StorageError::NATSError(e.to_string())) .map_err(|e| StoreError::NATSError(e.to_string()))
} }
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError> ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{ {
let watch_stream = self let watch_stream = self
.nats_store .nats_store
.watch_all() .watch_all()
.await .await
.map_err(|e| StorageError::NATSError(e.to_string()))?; .map_err(|e| StoreError::NATSError(e.to_string()))?;
// Map the `Entry` to `Entry.value` which is Bytes of the stored value. // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
Ok(Box::pin( Ok(Box::pin(
watch_stream.filter_map( watch_stream.filter_map(
...@@ -164,12 +168,12 @@ impl KeyValueBucket for NATSBucket { ...@@ -164,12 +168,12 @@ impl KeyValueBucket for NATSBucket {
)) ))
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let mut key_stream = self let mut key_stream = self
.nats_store .nats_store
.keys() .keys()
.await .await
.map_err(|e| StorageError::NATSError(e.to_string()))?; .map_err(|e| StoreError::NATSError(e.to_string()))?;
let mut out = HashMap::new(); let mut out = HashMap::new();
while let Some(Ok(key)) = key_stream.next().await { while let Some(Ok(key)) = key_stream.next().await {
if let Ok(Some(entry)) = self.nats_store.entry(&key).await { if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
...@@ -181,24 +185,24 @@ impl KeyValueBucket for NATSBucket { ...@@ -181,24 +185,24 @@ impl KeyValueBucket for NATSBucket {
} }
impl NATSBucket { impl NATSBucket {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> { async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
match self.nats_store.create(&key, value.to_string().into()).await { match self.nats_store.create(&key, value.to_string().into()).await {
Ok(revision) => Ok(StorageOutcome::Created(revision)), Ok(revision) => Ok(StoreOutcome::Created(revision)),
Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => { Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
// key exists, get the revsion // key exists, get the revsion
match self.nats_store.entry(key).await { match self.nats_store.entry(key).await {
Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)), Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
Ok(None) => { Ok(None) => {
tracing::error!( tracing::error!(
%key, %key,
"Race condition, key deleted between create and fetch. Retry." "Race condition, key deleted between create and fetch. Retry."
); );
Err(StorageError::Retry) Err(StoreError::Retry)
} }
Err(err) => Err(StorageError::NATSError(err.to_string())), Err(err) => Err(StoreError::NATSError(err.to_string())),
} }
} }
Err(err) => Err(StorageError::NATSError(err.to_string())), Err(err) => Err(StoreError::NATSError(err.to_string())),
} }
} }
...@@ -207,26 +211,26 @@ impl NATSBucket { ...@@ -207,26 +211,26 @@ impl NATSBucket {
key: &Key, key: &Key,
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StoreOutcome, StoreError> {
match self match self
.nats_store .nats_store
.update(key, value.to_string().into(), revision) .update(key, value.to_string().into(), revision)
.await .await
{ {
Ok(revision) => Ok(StorageOutcome::Created(revision)), Ok(revision) => Ok(StoreOutcome::Created(revision)),
Err(err) Err(err)
if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision => if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
{ {
tracing::warn!(revision, %key, "Update WrongLastRevision, resync"); tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
self.resync_update(key, value).await self.resync_update(key, value).await
} }
Err(err) => Err(StorageError::NATSError(err.to_string())), Err(err) => Err(StoreError::NATSError(err.to_string())),
} }
} }
/// We have the wrong revision for a key. Fetch it's entry to get the correct revision, /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
/// and try the update again. /// and try the update again.
async fn resync_update(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> { async fn resync_update(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
match self.nats_store.entry(key).await { match self.nats_store.entry(key).await {
Ok(Some(entry)) => { Ok(Some(entry)) => {
// Re-try the update with new version number // Re-try the update with new version number
...@@ -236,8 +240,8 @@ impl NATSBucket { ...@@ -236,8 +240,8 @@ impl NATSBucket {
.update(key, value.to_string().into(), next_rev) .update(key, value.to_string().into(), next_rev)
.await .await
{ {
Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)), Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
Err(err) => Err(StorageError::NATSError(format!( Err(err) => Err(StoreError::NATSError(format!(
"Error during update of key {key} after resync: {err}" "Error during update of key {key} after resync: {err}"
))), ))),
} }
...@@ -248,7 +252,7 @@ impl NATSBucket { ...@@ -248,7 +252,7 @@ impl NATSBucket {
} }
Err(err) => { Err(err) => {
tracing::error!(%key, %err, "Failed fetching entry during resync"); tracing::error!(%key, %err, "Failed fetching entry during resync");
Err(StorageError::NATSError(err.to_string())) Err(StoreError::NATSError(err.to_string()))
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment