"lib/runtime/src/vscode:/vscode.git/clone" did not exist on "e330d969b959f92f03065bf8a6170ce765e33f75"
Unverified Commit 7a7d397c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Introduce storage_client in DistributedRuntime (#3507)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 0a2a820b
......@@ -57,11 +57,11 @@ impl State {
}
}
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client,
etcd_client: Some(etcd_client),
}
}
......@@ -155,7 +155,10 @@ impl KserveServiceConfigBuilder {
let config: KserveServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
let state = match config.etcd_client {
Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
// enable prometheus metrics
let registry = metrics::Registry::new();
......
......@@ -52,16 +52,12 @@ async fn live_handler(
async fn health_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
let instances = if let Some(etcd_client) = state.etcd_client() {
match list_all_instances(etcd_client).await {
let instances = match list_all_instances(state.store()).await {
Ok(instances) => instances,
Err(err) => {
tracing::warn!("Failed to fetch instances from etcd: {}", err);
tracing::warn!(%err, "Failed to fetch instances from store");
vec![]
}
}
} else {
vec![]
};
let mut endpoints: Vec<String> = instances
......
......@@ -19,6 +19,9 @@ use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder;
use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::storage::key_value_store::EtcdStore;
use dynamo_runtime::storage::key_value_store::KeyValueStore;
use dynamo_runtime::storage::key_value_store::MemoryStore;
use dynamo_runtime::transports::etcd;
use std::net::SocketAddr;
use tokio::task::JoinHandle;
......@@ -26,11 +29,11 @@ use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
/// HTTP service shared state
#[derive(Default)]
pub struct State {
metrics: Arc<Metrics>,
manager: Arc<ModelManager>,
etcd_client: Option<etcd::Client>,
store: Arc<dyn KeyValueStore>,
flags: StateFlags,
}
......@@ -76,6 +79,7 @@ impl State {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client: None,
store: Arc::new(MemoryStore::new()),
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
......@@ -85,11 +89,12 @@ impl State {
}
}
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client,
store: Arc::new(EtcdStore::new(etcd_client.clone())),
etcd_client: Some(etcd_client),
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
......@@ -115,6 +120,10 @@ impl State {
self.etcd_client.as_ref()
}
pub fn store(&self) -> Arc<dyn KeyValueStore> {
self.store.clone()
}
// TODO
pub fn sse_keep_alive(&self) -> Option<Duration> {
None
......@@ -294,9 +303,10 @@ impl HttpServiceConfigBuilder {
let config: HttpServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new());
let etcd_client = config.etcd_client;
let state = Arc::new(State::new_with_etcd(model_manager, etcd_client));
let state = match config.etcd_client {
Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
None => Arc::new(State::new(model_manager)),
};
state
.flags
.set(&EndpointType::Chat, config.enable_chat_endpoints);
......
......@@ -12,7 +12,7 @@ use dynamo_runtime::storage::key_value_store::Key;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{
component::Endpoint,
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
storage::key_value_store::{EtcdStore, KeyValueStore, KeyValueStoreManager},
};
use crate::entrypoint::RouterConfig;
......@@ -409,7 +409,7 @@ impl LocalModel {
self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
let key = Key::from_raw(endpoint.unique_path(lease_id));
......
......@@ -23,7 +23,7 @@ use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{
EtcdStorage, Key, KeyValueStore, KeyValueStoreManager,
EtcdStore, Key, KeyValueStore, KeyValueStoreManager,
};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize};
......@@ -457,7 +457,7 @@ impl ModelDeploymentCard {
// Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
let store: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client));
let card_store = Arc::new(KeyValueStoreManager::new(store));
let Some(mut card) = card_store
.load::<ModelDeploymentCard>(ROOT_PATH, mdc_key)
......
......@@ -29,6 +29,8 @@
//!
//! TODO: Top-level Overview of Endpoints/Functions
use std::fmt;
use crate::{
config::HealthStatus,
discovery::Lease,
......@@ -70,7 +72,7 @@ pub mod service;
pub use client::{Client, InstanceSource};
/// The root etcd path where each instance registers itself in etcd.
/// The root key-value path where each instance registers itself in.
/// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "v1/instances";
......@@ -91,7 +93,7 @@ pub struct Registry {
inner: Arc<tokio::sync::Mutex<RegistryInner>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Instance {
pub component: String,
pub endpoint: String,
......@@ -113,6 +115,30 @@ impl Instance {
}
}
impl fmt::Display for Instance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}/{}/{}/{}",
self.namespace, self.component, self.endpoint, self.instance_id
)
}
}
/// Sort by string name
impl std::cmp::Ord for Instance {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.to_string().cmp(&other.to_string())
}
}
impl PartialOrd for Instance {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// Since Ord is fully implemented, the comparison is always total.
Some(self.cmp(other))
}
}
/// A [Component] a discoverable entity in the distributed runtime.
/// You can host [Endpoint] on a [Component] by first creating
/// a [Service] then adding one or more [Endpoint] to the [Service].
......@@ -197,8 +223,8 @@ impl MetricsRegistry for Component {
}
impl Component {
/// The component part of an instance path in etcd.
pub fn etcd_root(&self) -> String {
/// The component part of an instance path in key-value store.
pub fn instance_root(&self) -> String {
let ns = self.namespace.name();
let cp = &self.name;
format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}")
......@@ -240,27 +266,23 @@ impl Component {
}
pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let Some(etcd_client) = self.drt.etcd_client() else {
let client = self.drt.store();
let Some(bucket) = client.get_bucket(&self.instance_root()).await? else {
return Ok(vec![]);
};
let mut out = vec![];
// The extra slash is important to only list exact component matches, not substrings.
for kv in etcd_client
.kv_get_prefix(format!("{}/", self.etcd_root()))
.await?
{
let val = match serde_json::from_slice::<Instance>(kv.value()) {
let entries = bucket.entries().await?;
let mut instances = Vec::with_capacity(entries.len());
for (name, bytes) in entries.into_iter() {
let val = match serde_json::from_slice::<Instance>(&bytes) {
Ok(val) => val,
Err(err) => {
anyhow::bail!(
"Error converting etcd response to Instance: {err}. {}",
kv.value_str()?
);
anyhow::bail!("Error converting storage response to Instance: {err}. {name}",);
}
};
out.push(val);
instances.push(val);
}
Ok(out)
instances.sort();
Ok(instances)
}
/// Scrape ServiceSet, which contains NATS stats as well as user defined stats
......@@ -445,7 +467,7 @@ impl Endpoint {
/// The endpoint part of an instance path in etcd
pub fn etcd_root(&self) -> String {
let component_path = self.component.etcd_root();
let component_path = self.component.instance_root();
let endpoint_name = &self.name;
format!("{component_path}/{endpoint_name}")
}
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
pub use crate::component::Component;
use crate::storage::key_value_store::{EtcdStore, KeyValueStore, MemoryStore};
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
ErrorContext, RuntimeCallback,
......@@ -44,10 +45,14 @@ impl DistributedRuntime {
let runtime_clone = runtime.clone();
let etcd_client = if is_static {
None
let (etcd_client, store) = if is_static {
let store: Arc<dyn KeyValueStore> = Arc::new(MemoryStore::new());
(None, store)
} else {
Some(etcd::Client::new(etcd_config.clone(), runtime_clone).await?)
let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?;
let store: Arc<dyn KeyValueStore> = Arc::new(EtcdStore::new(etcd_client.clone()));
(Some(etcd_client), store)
};
let nats_client = nats_config.clone().connect().await?;
......@@ -77,6 +82,7 @@ impl DistributedRuntime {
let distributed_runtime = Self {
runtime,
etcd_client,
store,
nats_client,
tcp_server: Arc::new(OnceCell::new()),
system_status_server: Arc::new(OnceLock::new()),
......@@ -270,6 +276,12 @@ impl DistributedRuntime {
self.etcd_client.clone()
}
/// An interface to store things. Will eventually replace `etcd_client`.
/// Currently does key-value, but will grow to include whatever we need to store.
pub fn store(&self) -> Arc<dyn KeyValueStore> {
self.store.clone()
}
pub fn child_token(&self) -> CancellationToken {
self.runtime.child_token()
}
......
......@@ -7,28 +7,28 @@
//! the entire distributed system, complementing the component-specific
//! instance listing in `component.rs`.
use std::sync::Arc;
use crate::component::{INSTANCE_ROOT_PATH, Instance};
use crate::storage::key_value_store::KeyValueStore;
use crate::transports::etcd::Client as EtcdClient;
pub async fn list_all_instances(etcd_client: &EtcdClient) -> anyhow::Result<Vec<Instance>> {
let mut instances = Vec::new();
pub async fn list_all_instances(client: Arc<dyn KeyValueStore>) -> anyhow::Result<Vec<Instance>> {
let Some(bucket) = client.get_bucket(INSTANCE_ROOT_PATH).await? else {
return Ok(vec![]);
};
for kv in etcd_client
.kv_get_prefix(format!("{}/", INSTANCE_ROOT_PATH))
.await?
{
match serde_json::from_slice::<Instance>(kv.value()) {
let entries = bucket.entries().await?;
let mut instances = Vec::with_capacity(entries.len());
for (name, bytes) in entries.into_iter() {
match serde_json::from_slice::<Instance>(&bytes) {
Ok(instance) => instances.push(instance),
Err(err) => {
tracing::warn!(
"Failed to parse instance from etcd: {}. Key: {}, Value: {}",
err,
kv.key_str().unwrap_or("invalid_key"),
kv.value_str().unwrap_or("invalid_value")
);
tracing::warn!(%err, key = name, "Failed to parse instance from storage");
}
}
}
instances.sort();
Ok(instances)
}
......@@ -51,7 +51,9 @@ pub use system_health::{HealthCheckTarget, SystemHealth};
pub use tokio_util::sync::CancellationToken;
pub use worker::Worker;
use crate::metrics::prometheus_names::distributed_runtime;
use crate::{
metrics::prometheus_names::distributed_runtime, storage::key_value_store::KeyValueStore,
};
use component::{Endpoint, InstanceSource};
use utils::GracefulShutdownTracker;
......@@ -152,6 +154,7 @@ pub struct DistributedRuntime {
// we might consider a unifed transport manager here
etcd_client: Option<transports::etcd::Client>,
nats_client: transports::nats::Client,
store: Arc<dyn KeyValueStore>,
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
......
......@@ -17,11 +17,11 @@ use futures::StreamExt;
use serde::{Deserialize, Serialize};
mod mem;
pub use mem::MemoryStorage;
pub use mem::MemoryStore;
mod nats;
pub use nats::NATSStorage;
pub use nats::NATSStore;
mod etcd;
pub use etcd::EtcdStorage;
pub use etcd::EtcdStore;
/// A key that is safe to use directly in the KV store.
#[derive(Debug, Clone, PartialEq)]
......@@ -69,12 +69,14 @@ pub trait KeyValueStore: Send + Sync {
bucket_name: &str,
// auto-delete items older than this
ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError>;
) -> Result<Box<dyn KeyValueBucket>, StoreError>;
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError>;
) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError>;
fn connection_id(&self) -> u64;
}
pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
......@@ -88,7 +90,7 @@ impl KeyValueStoreManager {
&self,
bucket: &str,
key: &Key,
) -> Result<Option<T>, StorageError> {
) -> Result<Option<T>, StoreError> {
let Some(bucket) = self.0.get_bucket(bucket).await? else {
// No bucket means no cards
return Ok(None);
......@@ -101,7 +103,7 @@ impl KeyValueStoreManager {
Ok(None) => Ok(None),
Err(err) => {
// TODO look at what errors NATS can give us and make more specific wrappers
Err(StorageError::NATSError(err.to_string()))
Err(StoreError::NATSError(err.to_string()))
}
}
}
......@@ -114,7 +116,7 @@ impl KeyValueStoreManager {
bucket_name: &str,
bucket_ttl: Option<Duration>,
) -> (
tokio::task::JoinHandle<Result<(), StorageError>>,
tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<T>,
) {
let bucket_name = bucket_name.to_string();
......@@ -139,7 +141,7 @@ impl KeyValueStoreManager {
let _ = tx.send(card);
}
Ok::<(), StorageError>(())
Ok::<(), StoreError>(())
});
(watch_task, rx)
}
......@@ -150,14 +152,14 @@ impl KeyValueStoreManager {
bucket_ttl: Option<Duration>,
key: &Key,
obj: &mut T,
) -> anyhow::Result<StorageOutcome> {
) -> anyhow::Result<StoreOutcome> {
let obj_json = serde_json::to_string(obj)?;
let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
match outcome {
StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
obj.set_revision(revision);
}
}
......@@ -176,43 +178,43 @@ pub trait KeyValueBucket: Send {
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError>;
) -> Result<StoreOutcome, StoreError>;
/// Fetch an item from the key-value storage
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError>;
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
/// Delete an item from the bucket
async fn delete(&self, key: &Key) -> Result<(), StorageError>;
async fn delete(&self, key: &Key) -> Result<(), StoreError>;
/// A stream of items inserted into the bucket.
/// Every time the stream is polled it will either return a newly created entry, or block until
/// such time.
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>;
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StorageOutcome {
pub enum StoreOutcome {
/// The operation succeeded and created a new entry with this revision.
/// Note that "create" also means update, because each new revision is a "create".
Created(u64),
/// The operation did not do anything, the value was already present, with this revision.
Exists(u64),
}
impl fmt::Display for StorageOutcome {
impl fmt::Display for StoreOutcome {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StorageOutcome::Created(revision) => write!(f, "Created at {revision}"),
StorageOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum StorageError {
pub enum StoreError {
#[error("Could not find bucket '{0}'")]
MissingBucket(String),
......@@ -291,12 +293,12 @@ mod tests {
async fn test_memory_storage() -> anyhow::Result<()> {
init();
let s = Arc::new(MemoryStorage::new());
let s = Arc::new(MemoryStore::new());
let s2 = Arc::clone(&s);
let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
assert_eq!(res, StoreOutcome::Created(0));
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
let ingress = tokio::spawn(async move {
......@@ -315,27 +317,27 @@ mod tests {
let v = stream.next().await.unwrap();
assert_eq!(v, "value3".as_bytes());
Ok::<_, StorageError>(())
Ok::<_, StoreError>(())
});
// MemoryStorage uses a HashMap with no inherent ordering, so we must ensure test1 is
// MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
// fetched before test2 is inserted, otherwise they can come out in any order, and we
// wouldn't be testing the watch behavior.
got_first_rx.await?;
let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
assert_eq!(res, StoreOutcome::Created(0));
// Repeat a key and revision. Ignored.
let res = bucket.insert(&"test2".into(), "value2", 0).await?;
assert_eq!(res, StorageOutcome::Exists(0));
assert_eq!(res, StoreOutcome::Exists(0));
// Increment revision
let res = bucket.insert(&"test2".into(), "value2", 1).await?;
assert_eq!(res, StorageOutcome::Created(1));
assert_eq!(res, StoreOutcome::Created(1));
let res = bucket.insert(&"test3".into(), "value3", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
assert_eq!(res, StoreOutcome::Created(0));
// ingress exits once it has received all values
let _ = ingress.await?;
......@@ -347,12 +349,12 @@ mod tests {
async fn test_broadcast_stream() -> anyhow::Result<()> {
init();
let s: &'static _ = Box::leak(Box::new(MemoryStorage::new()));
let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
let bucket: &'static _ =
Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));
let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StorageOutcome::Created(0));
assert_eq!(res, StoreOutcome::Created(0));
let stream = bucket.watch().await?;
let tap = TappableStream::new(stream, 10).await;
......
......@@ -10,27 +10,27 @@ use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)]
pub struct EtcdStorage {
pub struct EtcdStore {
client: Client,
}
impl EtcdStorage {
impl EtcdStore {
pub fn new(client: Client) -> Self {
Self { client }
}
}
#[async_trait]
impl KeyValueStore for EtcdStorage {
impl KeyValueStore for EtcdStore {
/// A "bucket" in etcd is a path prefix
async fn get_or_create_bucket(
&self,
bucket_name: &str,
_ttl: Option<Duration>, // TODO ttl not used yet
) -> Result<Box<dyn KeyValueBucket>, StorageError> {
) -> Result<Box<dyn KeyValueBucket>, StoreError> {
Ok(self.get_bucket(bucket_name).await?.unwrap())
}
......@@ -39,12 +39,18 @@ impl KeyValueStore for EtcdStorage {
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
Ok(Some(Box::new(EtcdBucket {
client: self.client.clone(),
bucket_name: bucket_name.to_string(),
})))
}
fn connection_id(&self) -> u64 {
// This conversion from i64 to u64 is safe because etcd lease IDs are u64 internally.
// They present as i64 because of the limitations of the etcd grpc/HTTP JSON API.
self.client.lease_id() as u64
}
}
pub struct EtcdBucket {
......@@ -60,7 +66,7 @@ impl KeyValueBucket for EtcdBucket {
value: &str,
// "version" in etcd speak. revision is a global cluster-wide value
revision: u64,
) -> Result<StorageOutcome, StorageError> {
) -> Result<StoreOutcome, StoreError> {
let version = revision;
if version == 0 {
self.create(key, value).await
......@@ -69,7 +75,7 @@ impl KeyValueBucket for EtcdBucket {
}
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd get: {k}");
......@@ -77,7 +83,7 @@ impl KeyValueBucket for EtcdBucket {
.client
.kv_get(k, None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
if kvs.is_empty() {
return Ok(None);
}
......@@ -85,20 +91,20 @@ impl KeyValueBucket for EtcdBucket {
Ok(Some(val.into()))
}
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd delete: {k}");
let _ = self
.client
.kv_delete(k, None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
Ok(())
}
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {k}");
......@@ -108,7 +114,7 @@ impl KeyValueBucket for EtcdBucket {
.clone()
.watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let output = stream! {
while let Ok(Some(resp)) = watch_stream.message().await {
for e in resp.events() {
......@@ -122,7 +128,7 @@ impl KeyValueBucket for EtcdBucket {
Ok(Box::pin(output))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}");
......@@ -130,7 +136,7 @@ impl KeyValueBucket for EtcdBucket {
.client
.kv_get_prefix(k)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let out: HashMap<String, bytes::Bytes> = resp
.into_iter()
.map(|kv| {
......@@ -144,7 +150,7 @@ impl KeyValueBucket for EtcdBucket {
}
impl EtcdBucket {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}");
......@@ -166,11 +172,11 @@ impl EtcdBucket {
.kv_client()
.txn(txn)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
if result.succeeded() {
// Key was created successfully
return Ok(StorageOutcome::Created(1)); // version of new key is always 1
return Ok(StoreOutcome::Created(1)); // version of new key is always 1
}
// Key already existed, get its version
......@@ -179,10 +185,10 @@ impl EtcdBucket {
&& let Some(kv) = get_resp.kvs().first()
{
let version = kv.version() as u64;
return Ok(StorageOutcome::Exists(version));
return Ok(StoreOutcome::Exists(version));
}
// Shouldn't happen, but handle edge case
Err(StorageError::EtcdError(
Err(StoreError::EtcdError(
"Unexpected transaction response".to_string(),
))
}
......@@ -192,7 +198,7 @@ impl EtcdBucket {
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
) -> Result<StoreOutcome, StoreError> {
let version = revision;
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd update: {k}");
......@@ -201,9 +207,9 @@ impl EtcdBucket {
.client
.kv_get(k.clone(), None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
if kvs.is_empty() {
return Err(StorageError::MissingKey(key.to_string()));
return Err(StoreError::MissingKey(key.to_string()));
}
let current_version = kvs.first().unwrap().version() as u64;
if current_version != version + 1 {
......@@ -224,17 +230,17 @@ impl EtcdBucket {
.client
.kv_put_with_options(k, value, Some(put_options))
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
Ok(match put_resp.take_prev_key() {
// Should this be an error?
// The key was deleted between our get and put. We re-created it.
// Version of new key is always 1.
// <https://etcd.io/docs/v3.5/learning/data_model/>
None => StorageOutcome::Created(1),
None => StoreOutcome::Created(1),
// Expected case, success
Some(kv) if kv.version() as u64 == version + 1 => StorageOutcome::Created(version),
Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
// Should this be an error? Something updated the version between our get and put
Some(kv) => StorageOutcome::Created(kv.version() as u64 + 1),
Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
})
}
}
......@@ -263,9 +269,9 @@ mod concurrent_create_tests {
});
}
async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
let etcd_client = drt.etcd_client().expect("etcd client should be available");
let storage = EtcdStorage::new(etcd_client);
let storage = EtcdStore::new(etcd_client);
// Create a bucket for testing
let bucket = Arc::new(tokio::sync::Mutex::new(
......@@ -307,7 +313,7 @@ mod concurrent_create_tests {
.await;
match result {
Ok(StorageOutcome::Created(version)) => {
Ok(StoreOutcome::Created(version)) => {
println!(
"Worker {} successfully created key with version {}",
worker_id, version
......@@ -316,7 +322,7 @@ mod concurrent_create_tests {
*count += 1;
Ok(version)
}
Ok(StorageOutcome::Exists(version)) => {
Ok(StoreOutcome::Exists(version)) => {
println!(
"Worker {} found key already exists with version {}",
worker_id, version
......
......@@ -8,25 +8,27 @@ use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use rand::Rng as _;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)]
pub struct MemoryStorage {
inner: Arc<MemoryStorageInner>,
pub struct MemoryStore {
inner: Arc<MemoryStoreInner>,
connection_id: u64,
}
impl Default for MemoryStorage {
impl Default for MemoryStore {
fn default() -> Self {
Self::new()
}
}
struct MemoryStorageInner {
struct MemoryStoreInner {
data: Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<(String, String)>,
change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
......@@ -34,7 +36,7 @@ struct MemoryStorageInner {
pub struct MemoryBucketRef {
name: String,
inner: Arc<MemoryStorageInner>,
inner: Arc<MemoryStoreInner>,
}
struct MemoryBucket {
......@@ -49,27 +51,28 @@ impl MemoryBucket {
}
}
impl MemoryStorage {
impl MemoryStore {
pub fn new() -> Self {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
MemoryStorage {
inner: Arc::new(MemoryStorageInner {
MemoryStore {
inner: Arc::new(MemoryStoreInner {
data: Mutex::new(HashMap::new()),
change_sender: tx,
change_receiver: Mutex::new(rx),
}),
connection_id: rand::rng().random(),
}
}
}
#[async_trait]
impl KeyValueStore for MemoryStorage {
impl KeyValueStore for MemoryStore {
async fn get_or_create_bucket(
&self,
bucket_name: &str,
// MemoryStorage doesn't respect TTL yet
// MemoryStore doesn't respect TTL yet
_ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError> {
) -> Result<Box<dyn KeyValueBucket>, StoreError> {
let mut locked_data = self.inner.data.lock().await;
// Ensure the bucket exists
locked_data
......@@ -82,11 +85,11 @@ impl KeyValueStore for MemoryStorage {
}))
}
/// This operation cannot fail on MemoryStorage. Always returns Ok.
/// This operation cannot fail on MemoryStore. Always returns Ok.
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
let locked_data = self.inner.data.lock().await;
match locked_data.get(bucket_name) {
Some(_) => Ok(Some(Box::new(MemoryBucketRef {
......@@ -96,6 +99,10 @@ impl KeyValueStore for MemoryStorage {
None => Ok(None),
}
}
fn connection_id(&self) -> u64 {
self.connection_id
}
}
#[async_trait]
......@@ -105,11 +112,11 @@ impl KeyValueBucket for MemoryBucketRef {
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
) -> Result<StoreOutcome, StoreError> {
let mut locked_data = self.inner.data.lock().await;
let mut b = locked_data.get_mut(&self.name);
let Some(bucket) = b.as_mut() else {
return Err(StorageError::MissingBucket(self.name.to_string()));
return Err(StoreError::MissingBucket(self.name.to_string()));
};
let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => {
......@@ -118,22 +125,22 @@ impl KeyValueBucket for MemoryBucketRef {
.inner
.change_sender
.send((key.to_string(), value.to_string()));
StorageOutcome::Created(revision)
StoreOutcome::Created(revision)
}
Entry::Occupied(mut entry) => {
let (rev, _v) = entry.get();
if *rev == revision {
StorageOutcome::Exists(revision)
StoreOutcome::Exists(revision)
} else {
entry.insert((revision, value.to_string()));
StorageOutcome::Created(revision)
StoreOutcome::Created(revision)
}
}
};
Ok(outcome)
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None);
......@@ -144,10 +151,10 @@ impl KeyValueBucket for MemoryBucketRef {
.map(|(_, v)| bytes::Bytes::from(v.clone())))
}
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let mut locked_data = self.inner.data.lock().await;
let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StorageError::MissingBucket(self.name.to_string()));
return Err(StoreError::MissingBucket(self.name.to_string()));
};
bucket.data.remove(&key.0);
Ok(())
......@@ -158,7 +165,7 @@ impl KeyValueBucket for MemoryBucketRef {
/// Caller takes the lock so only a single caller may use this at once.
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
Ok(Box::pin(async_stream::stream! {
// All the existing ones first
......@@ -192,7 +199,7 @@ impl KeyValueBucket for MemoryBucketRef {
}))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await;
match locked_data.get(&self.name) {
Some(bucket) => Ok(bucket
......@@ -200,7 +207,7 @@ impl KeyValueBucket for MemoryBucketRef {
.iter()
.map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
.collect()),
None => Err(StorageError::MissingBucket(self.name.clone())),
None => Err(StoreError::MissingBucket(self.name.clone())),
}
}
}
......@@ -9,10 +9,10 @@ use crate::{
use async_trait::async_trait;
use futures::StreamExt;
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
#[derive(Clone)]
pub struct NATSStorage {
pub struct NATSStore {
client: Client,
endpoint: EndpointId,
}
......@@ -22,12 +22,12 @@ pub struct NATSBucket {
}
#[async_trait]
impl KeyValueStore for NATSStorage {
impl KeyValueStore for NATSStore {
async fn get_or_create_bucket(
&self,
bucket_name: &str,
ttl: Option<Duration>,
) -> Result<Box<dyn KeyValueBucket>, StorageError> {
) -> Result<Box<dyn KeyValueBucket>, StoreError> {
let name = Slug::slugify(bucket_name);
let nats_store = self
.get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
......@@ -38,18 +38,22 @@ impl KeyValueStore for NATSStorage {
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
let name = Slug::slugify(bucket_name);
match self.get_key_value(&self.endpoint.namespace, &name).await? {
Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
None => Ok(None),
}
}
fn connection_id(&self) -> u64 {
self.client.client().server_info().client_id
}
}
impl NATSStorage {
impl NATSStore {
pub fn new(client: Client, endpoint: EndpointId) -> Self {
NATSStorage { client, endpoint }
NATSStore { client, endpoint }
}
/// Get or create a key-value store (aka bucket) in NATS.
......@@ -62,7 +66,7 @@ impl NATSStorage {
bucket_name: &Slug,
// Delete entries older than this
ttl: Option<Duration>,
) -> Result<async_nats::jetstream::kv::Store, StorageError> {
) -> Result<async_nats::jetstream::kv::Store, StoreError> {
if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
return Ok(kv);
}
......@@ -82,7 +86,7 @@ impl NATSStorage {
)
.await;
let nats_store = create_result
.map_err(|err| StorageError::KeyValueError(err.to_string(), bucket_name.clone()))?;
.map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
tracing::debug!("Created bucket {bucket_name}");
Ok(nats_store)
}
......@@ -91,7 +95,7 @@ impl NATSStorage {
&self,
namespace: &str,
bucket_name: &Slug,
) -> Result<Option<async_nats::jetstream::kv::Store>, StorageError> {
) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
let bucket_name = single_name(namespace, bucket_name);
let js = self.client.jetstream();
......@@ -102,7 +106,7 @@ impl NATSStorage {
// bucket doesn't exist
Ok(None)
}
Err(err) => Err(StorageError::KeyValueError(err.to_string(), bucket_name)),
Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
}
}
}
......@@ -114,7 +118,7 @@ impl KeyValueBucket for NATSBucket {
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
) -> Result<StoreOutcome, StoreError> {
if revision == 0 {
self.create(key, value).await
} else {
......@@ -122,29 +126,29 @@ impl KeyValueBucket for NATSBucket {
}
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
self.nats_store
.get(key)
.await
.map_err(|e| StorageError::NATSError(e.to_string()))
.map_err(|e| StoreError::NATSError(e.to_string()))
}
async fn delete(&self, key: &Key) -> Result<(), StorageError> {
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
self.nats_store
.delete(key)
.await
.map_err(|e| StorageError::NATSError(e.to_string()))
.map_err(|e| StoreError::NATSError(e.to_string()))
}
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
let watch_stream = self
.nats_store
.watch_all()
.await
.map_err(|e| StorageError::NATSError(e.to_string()))?;
.map_err(|e| StoreError::NATSError(e.to_string()))?;
// Map the `Entry` to `Entry.value` which is Bytes of the stored value.
Ok(Box::pin(
watch_stream.filter_map(
......@@ -164,12 +168,12 @@ impl KeyValueBucket for NATSBucket {
))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let mut key_stream = self
.nats_store
.keys()
.await
.map_err(|e| StorageError::NATSError(e.to_string()))?;
.map_err(|e| StoreError::NATSError(e.to_string()))?;
let mut out = HashMap::new();
while let Some(Ok(key)) = key_stream.next().await {
if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
......@@ -181,24 +185,24 @@ impl KeyValueBucket for NATSBucket {
}
impl NATSBucket {
async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
match self.nats_store.create(&key, value.to_string().into()).await {
Ok(revision) => Ok(StorageOutcome::Created(revision)),
Ok(revision) => Ok(StoreOutcome::Created(revision)),
Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
// key exists, get the revsion
match self.nats_store.entry(key).await {
Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
Ok(None) => {
tracing::error!(
%key,
"Race condition, key deleted between create and fetch. Retry."
);
Err(StorageError::Retry)
Err(StoreError::Retry)
}
Err(err) => Err(StorageError::NATSError(err.to_string())),
Err(err) => Err(StoreError::NATSError(err.to_string())),
}
}
Err(err) => Err(StorageError::NATSError(err.to_string())),
Err(err) => Err(StoreError::NATSError(err.to_string())),
}
}
......@@ -207,26 +211,26 @@ impl NATSBucket {
key: &Key,
value: &str,
revision: u64,
) -> Result<StorageOutcome, StorageError> {
) -> Result<StoreOutcome, StoreError> {
match self
.nats_store
.update(key, value.to_string().into(), revision)
.await
{
Ok(revision) => Ok(StorageOutcome::Created(revision)),
Ok(revision) => Ok(StoreOutcome::Created(revision)),
Err(err)
if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
{
tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
self.resync_update(key, value).await
}
Err(err) => Err(StorageError::NATSError(err.to_string())),
Err(err) => Err(StoreError::NATSError(err.to_string())),
}
}
/// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
/// and try the update again.
async fn resync_update(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
async fn resync_update(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
match self.nats_store.entry(key).await {
Ok(Some(entry)) => {
// Re-try the update with new version number
......@@ -236,8 +240,8 @@ impl NATSBucket {
.update(key, value.to_string().into(), next_rev)
.await
{
Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
Err(err) => Err(StorageError::NATSError(format!(
Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
Err(err) => Err(StoreError::NATSError(format!(
"Error during update of key {key} after resync: {err}"
))),
}
......@@ -248,7 +252,7 @@ impl NATSBucket {
}
Err(err) => {
tracing::error!(%key, %err, "Failed fetching entry during resync");
Err(StorageError::NATSError(err.to_string()))
Err(StoreError::NATSError(err.to_string()))
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment