Commit b77696da (unverified), authored by Graham King, committed by GitHub
Browse files

chore(runtime): Watched storage buckets now also return Delete events (#3875)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 9defc01b
......@@ -62,6 +62,24 @@ impl From<&Key> for String {
}
}
/// A bucket entry: a key together with its raw value bytes.
///
/// This is the payload carried by both variants of `WatchEvent`.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
/// The entry's key, as reported by the backing store.
key: String,
/// The entry's raw value. Backends may leave this empty on delete events
/// (the in-memory store sends an empty buffer for deletes).
value: bytes::Bytes,
}
impl KeyValue {
    /// Builds a `KeyValue` from an owned key and its raw value bytes.
    pub fn new(key: String, value: bytes::Bytes) -> Self {
        Self { key, value }
    }

    /// Returns the entry's key.
    ///
    /// The fields are private, so this accessor is how code outside this
    /// module reads the key of a received watch event.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Returns the entry's raw value bytes without copying them.
    pub fn value(&self) -> &bytes::Bytes {
        &self.value
    }
}
/// A change reported while watching a bucket.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
/// A key was written. The store manager's `watch` also emits one `Put` for
/// every entry that already exists when the watch starts.
Put(KeyValue),
/// A key was removed. The carried value depends on the backend and may be empty.
Delete(KeyValue),
}
#[async_trait]
pub trait KeyValueStore: Send + Sync {
type Bucket: KeyValueBucket + Send + Sync + 'static;
......@@ -196,14 +214,14 @@ impl KeyValueStoreManager {
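/// Returns a receiver that first yields a `WatchEvent::Put` for every entry already
/// in the bucket, then blocks and yields `Put` / `Delete` events as the store reports them.
/// Starts a task that watches the store until `cancel_token` is cancelled or the
/// underlying watch stream ends; the task's `JoinHandle` is returned alongside the receiver.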
pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>(
pub fn watch(
self: Arc<Self>,
bucket_name: &str,
bucket_ttl: Option<Duration>,
cancel_token: CancellationToken,
) -> (
tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<T>,
tokio::sync::mpsc::UnboundedReceiver<WatchEvent>,
) {
let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
......@@ -216,22 +234,20 @@ impl KeyValueStoreManager {
let mut stream = bucket.watch().await?;
// Send all the existing keys
for (_, card_bytes) in bucket.entries().await? {
let card: T = serde_json::from_slice(card_bytes.as_ref())?;
let _ = tx.send(card);
for (key, bytes) in bucket.entries().await? {
let _ = tx.send(WatchEvent::Put(KeyValue::new(key, bytes)));
}
// Now block waiting for new entries
loop {
let card_bytes = tokio::select! {
let event = tokio::select! {
_ = cancel_token.cancelled() => break,
result = stream.next() => match result {
Some(bytes) => bytes,
Some(event) => event,
None => break,
}
};
let card: T = serde_json::from_slice(card_bytes.as_ref())?;
let _ = tx.send(card);
let _ = tx.send(event);
}
Ok::<(), StoreError>(())
......@@ -284,7 +300,7 @@ pub trait KeyValueBucket: Send + Sync {
/// such time.
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + '_>>, StoreError>;
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
......@@ -353,14 +369,14 @@ mod tests {
/// clients can listen to.
#[allow(dead_code)]
pub struct TappableStream {
tx: tokio::sync::broadcast::Sender<bytes::Bytes>,
tx: tokio::sync::broadcast::Sender<WatchEvent>,
}
#[allow(dead_code)]
impl TappableStream {
async fn new<T>(stream: T, max_size: usize) -> Self
where
T: futures::Stream<Item = bytes::Bytes> + Send + 'static,
T: futures::Stream<Item = WatchEvent> + Send + 'static,
{
let (tx, _) = tokio::sync::broadcast::channel(max_size);
let tx2 = tx.clone();
......@@ -373,7 +389,7 @@ mod tests {
TappableStream { tx }
}
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
self.tx.subscribe()
}
}
......@@ -393,6 +409,15 @@ mod tests {
let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StoreOutcome::Created(0));
let mut expected = Vec::with_capacity(3);
for i in 1..=3 {
let item = WatchEvent::Put(KeyValue::new(
format!("test{i}"),
bytes::Bytes::from(format!("value{i}").into_bytes()),
));
expected.push(item);
}
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
let ingress = tokio::spawn(async move {
let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
......@@ -400,15 +425,16 @@ mod tests {
// Put in before starting the watch-all
let v = stream.next().await.unwrap();
assert_eq!(v, "value1".as_bytes());
assert_eq!(v, expected[0]);
got_first_tx.send(()).unwrap();
// Put in after
let v = stream.next().await.unwrap();
assert_eq!(v, "value2".as_bytes());
assert_eq!(v, expected[1]);
let v = stream.next().await.unwrap();
assert_eq!(v, "value3".as_bytes());
assert_eq!(v, expected[2]);
Ok::<_, StoreError>(())
});
......@@ -455,13 +481,18 @@ mod tests {
let mut rx1 = tap.subscribe();
let mut rx2 = tap.subscribe();
let item = WatchEvent::Put(KeyValue::new(
"test1".to_string(),
bytes::Bytes::from(b"GK".as_slice()),
));
let item_clone = item.clone();
let handle1 = tokio::spawn(async move {
let b = rx1.recv().await.unwrap();
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
assert_eq!(b, item_clone);
});
let handle2 = tokio::spawn(async move {
let b = rx2.recv().await.unwrap();
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
assert_eq!(b, item);
});
bucket.insert(&"test1".into(), "GK", 1).await?;
......
......@@ -5,7 +5,10 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use crate::{storage::key_value_store::Key, transports::etcd::Client};
use crate::{
storage::key_value_store::{Key, KeyValue, WatchEvent},
transports::etcd::Client,
};
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
......@@ -104,24 +107,39 @@ impl KeyValueBucket for EtcdBucket {
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {k}");
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
let prefix = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {prefix}");
let (watcher, mut watch_stream) = self
.client
.etcd_client()
.clone()
.watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
.watch(prefix.as_bytes(), Some(WatchOptions::new().with_prefix()))
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let output = stream! {
let _watcher = watcher; // Keep it alive. Not sure if necessary.
while let Ok(Some(resp)) = watch_stream.message().await {
for e in resp.events() {
if matches!(e.event_type(), EventType::Put) && e.kv().is_some() {
let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into();
yield b;
let Some(kv) = e.kv() else {
continue;
};
let (k_bytes, v_bytes) = kv.clone().into_key_value();
let key = match String::from_utf8(k_bytes) {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
}
};
let item = KeyValue::new(key, v_bytes.into());
match e.event_type() {
EventType::Put => {
yield WatchEvent::Put(item);
}
EventType::Delete => {
yield WatchEvent::Delete(item);
}
}
}
}
......
......@@ -9,13 +9,18 @@ use std::time::Duration;
use async_trait::async_trait;
use rand::Rng as _;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key;
use crate::storage::key_value_store::{Key, KeyValue, WatchEvent};
use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
/// Change notification passed over the memory store's internal change channel
/// and translated into `WatchEvent`s by the bucket's `watch` stream.
#[derive(Clone, Debug)]
enum MemoryEvent {
/// Sent by `insert` when it creates a key that was not present before.
Put { key: String, value: String },
/// Sent by `delete` only when the key actually existed and was removed.
Delete { key: String },
}
#[derive(Clone)]
pub struct MemoryStore {
inner: Arc<MemoryStoreInner>,
......@@ -29,9 +34,9 @@ impl Default for MemoryStore {
}
struct MemoryStoreInner {
data: Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<(String, String)>,
change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<MemoryEvent>,
change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
}
pub struct MemoryBucketRef {
......@@ -56,9 +61,9 @@ impl MemoryStore {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
MemoryStore {
inner: Arc::new(MemoryStoreInner {
data: Mutex::new(HashMap::new()),
data: parking_lot::Mutex::new(HashMap::new()),
change_sender: tx,
change_receiver: Mutex::new(rx),
change_receiver: tokio::sync::Mutex::new(rx),
}),
connection_id: rand::rng().random(),
}
......@@ -75,7 +80,7 @@ impl KeyValueStore for MemoryStore {
// MemoryStore doesn't respect TTL yet
_ttl: Option<Duration>,
) -> Result<Self::Bucket, StoreError> {
let mut locked_data = self.inner.data.lock().await;
let mut locked_data = self.inner.data.lock();
// Ensure the bucket exists
locked_data
.entry(bucket_name.to_string())
......@@ -89,7 +94,7 @@ impl KeyValueStore for MemoryStore {
/// This operation cannot fail on MemoryStore. Always returns Ok.
async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
let locked_data = self.inner.data.lock().await;
let locked_data = self.inner.data.lock();
match locked_data.get(bucket_name) {
Some(_) => Ok(Some(MemoryBucketRef {
name: bucket_name.to_string(),
......@@ -112,7 +117,7 @@ impl KeyValueBucket for MemoryBucketRef {
value: &str,
revision: u64,
) -> Result<StoreOutcome, StoreError> {
let mut locked_data = self.inner.data.lock().await;
let mut locked_data = self.inner.data.lock();
let mut b = locked_data.get_mut(&self.name);
let Some(bucket) = b.as_mut() else {
return Err(StoreError::MissingBucket(self.name.to_string()));
......@@ -120,10 +125,10 @@ impl KeyValueBucket for MemoryBucketRef {
let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => {
e.insert((revision, value.to_string()));
let _ = self
.inner
.change_sender
.send((key.to_string(), value.to_string()));
let _ = self.inner.change_sender.send(MemoryEvent::Put {
key: key.to_string(),
value: value.to_string(),
});
StoreOutcome::Created(revision)
}
Entry::Occupied(mut entry) => {
......@@ -140,7 +145,7 @@ impl KeyValueBucket for MemoryBucketRef {
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await;
let locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None);
};
......@@ -151,11 +156,15 @@ impl KeyValueBucket for MemoryBucketRef {
}
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let mut locked_data = self.inner.data.lock().await;
let mut locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
bucket.data.remove(&key.0);
if bucket.data.remove(&key.0).is_some() {
let _ = self.inner.change_sender.send(MemoryEvent::Delete {
key: key.to_string(),
});
}
Ok(())
}
......@@ -164,21 +173,25 @@ impl KeyValueBucket for MemoryBucketRef {
/// Caller takes the lock so only a single caller may use this at once.
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
// All the existing ones first
let mut existing_items = vec![];
let mut seen_keys = HashSet::new();
let data_lock = self.inner.data.lock();
let Some(bucket) = data_lock.get(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
for (key, (_rev, v)) in &bucket.data {
seen_keys.insert(key.clone());
let item = KeyValue::new(key.clone(), bytes::Bytes::from(v.clone().into_bytes()));
existing_items.push(WatchEvent::Put(item));
}
drop(data_lock);
Ok(Box::pin(async_stream::stream! {
// All the existing ones first
let mut seen = HashSet::new();
let data_lock = self.inner.data.lock().await;
let Some(bucket) = data_lock.get(&self.name) else {
tracing::error!(bucket_name = self.name, "watch: Missing bucket");
return;
};
for (_rev, v) in bucket.data.values() {
seen.insert(v.clone());
yield bytes::Bytes::from(v.clone());
for event in existing_items {
yield event;
}
drop(data_lock);
// Now any new ones
let mut rcv_lock = self.inner.change_receiver.lock().await;
loop {
......@@ -187,11 +200,16 @@ impl KeyValueBucket for MemoryBucketRef {
// Channel is closed, no more values coming
break;
},
Some((_k, v)) => {
if seen.contains(&v) {
Some(MemoryEvent::Put { key, value }) => {
if seen_keys.contains(&key) {
continue;
}
yield bytes::Bytes::from(v.clone());
let item = KeyValue::new(key, bytes::Bytes::from(value));
yield WatchEvent::Put(item);
},
Some(MemoryEvent::Delete { key }) => {
let item = KeyValue::new(key, bytes::Bytes::new());
yield WatchEvent::Delete(item);
}
}
}
......@@ -199,7 +217,7 @@ impl KeyValueBucket for MemoryBucketRef {
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await;
let locked_data = self.inner.data.lock();
match locked_data.get(&self.name) {
Some(bucket) => Ok(bucket
.data
......
......@@ -4,8 +4,12 @@
use std::{collections::HashMap, pin::Pin, time::Duration};
use crate::{
protocols::EndpointId, slug::Slug, storage::key_value_store::Key, transports::nats::Client,
protocols::EndpointId,
slug::Slug,
storage::key_value_store::{Key, KeyValue, WatchEvent},
transports::nats::Client,
};
use async_nats::jetstream::kv::Operation;
use async_trait::async_trait;
use futures::StreamExt;
......@@ -141,8 +145,7 @@ impl KeyValueBucket for NATSBucket {
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
{
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
let watch_stream = self
.nats_store
.watch_all()
......@@ -156,7 +159,15 @@ impl KeyValueBucket for NATSBucket {
async_nats::error::Error<_>,
>| async move {
match maybe_entry {
Ok(entry) => Some(entry.value),
Ok(entry) => {
let item = KeyValue::new(entry.key, entry.value);
Some(match entry.operation {
Operation::Put => WatchEvent::Put(item),
Operation::Delete => WatchEvent::Delete(item),
// TODO: What is Purge? Not urgent, NATS impl not used
Operation::Purge => WatchEvent::Delete(item),
})
}
Err(e) => {
tracing::error!(error=%e, "watch fatal err");
None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment