Unverified Commit b77696da authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(runtime): Watched storage buckets now also return Delete events (#3875)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 9defc01b
...@@ -62,6 +62,24 @@ impl From<&Key> for String { ...@@ -62,6 +62,24 @@ impl From<&Key> for String {
} }
} }
/// A single key together with its raw value bytes.
///
/// This is the payload carried by both [`WatchEvent`] variants.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
    key: String,
    value: bytes::Bytes,
}

impl KeyValue {
    /// Pairs `key` with `value`. Both are stored exactly as given; no
    /// validation or copying of the value bytes is performed here.
    pub fn new(key: String, value: bytes::Bytes) -> Self {
        Self { key, value }
    }

    /// The key of this entry.
    ///
    /// The fields are private, so code outside this module needs this
    /// accessor to find out which key a received event refers to.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// The raw value bytes of this entry.
    ///
    /// Returned by reference so the caller decides whether to borrow the
    /// bytes or clone the handle.
    pub fn value(&self) -> &bytes::Bytes {
        &self.value
    }
}
/// A change reported by a watched bucket.
///
/// Entries that already exist when a watch starts are also delivered as
/// `Put` events, ahead of any live changes.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
// A key was written; carries the key and the bytes stored under it.
Put(KeyValue),
// A key was removed. What the carried value holds depends on the backend:
// the in-memory store sends empty bytes, while the etcd and NATS backends
// forward whatever value their delete notification carries.
Delete(KeyValue),
}
#[async_trait] #[async_trait]
pub trait KeyValueStore: Send + Sync { pub trait KeyValueStore: Send + Sync {
type Bucket: KeyValueBucket + Send + Sync + 'static; type Bucket: KeyValueBucket + Send + Sync + 'static;
...@@ -196,14 +214,14 @@ impl KeyValueStoreManager { ...@@ -196,14 +214,14 @@ impl KeyValueStoreManager {
/// Returns a receiver that will receive all the existing keys, and /// Returns a receiver that will receive all the existing keys, and
/// then block and receive new keys as they are created. /// then block and receive new keys as they are created.
/// Starts a task that runs forever, watches the store. /// Starts a task that runs forever, watches the store.
pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>( pub fn watch(
self: Arc<Self>, self: Arc<Self>,
bucket_name: &str, bucket_name: &str,
bucket_ttl: Option<Duration>, bucket_ttl: Option<Duration>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> ( ) -> (
tokio::task::JoinHandle<Result<(), StoreError>>, tokio::task::JoinHandle<Result<(), StoreError>>,
tokio::sync::mpsc::UnboundedReceiver<T>, tokio::sync::mpsc::UnboundedReceiver<WatchEvent>,
) { ) {
let bucket_name = bucket_name.to_string(); let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
...@@ -216,22 +234,20 @@ impl KeyValueStoreManager { ...@@ -216,22 +234,20 @@ impl KeyValueStoreManager {
let mut stream = bucket.watch().await?; let mut stream = bucket.watch().await?;
// Send all the existing keys // Send all the existing keys
for (_, card_bytes) in bucket.entries().await? { for (key, bytes) in bucket.entries().await? {
let card: T = serde_json::from_slice(card_bytes.as_ref())?; let _ = tx.send(WatchEvent::Put(KeyValue::new(key, bytes)));
let _ = tx.send(card);
} }
// Now block waiting for new entries // Now block waiting for new entries
loop { loop {
let card_bytes = tokio::select! { let event = tokio::select! {
_ = cancel_token.cancelled() => break, _ = cancel_token.cancelled() => break,
result = stream.next() => match result { result = stream.next() => match result {
Some(bytes) => bytes, Some(event) => event,
None => break, None => break,
} }
}; };
let card: T = serde_json::from_slice(card_bytes.as_ref())?; let _ = tx.send(event);
let _ = tx.send(card);
} }
Ok::<(), StoreError>(()) Ok::<(), StoreError>(())
...@@ -284,7 +300,7 @@ pub trait KeyValueBucket: Send + Sync { ...@@ -284,7 +300,7 @@ pub trait KeyValueBucket: Send + Sync {
/// such time. /// such time.
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + '_>>, StoreError>; ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>; async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
} }
...@@ -353,14 +369,14 @@ mod tests { ...@@ -353,14 +369,14 @@ mod tests {
/// clients can listen to. /// clients can listen to.
#[allow(dead_code)] #[allow(dead_code)]
pub struct TappableStream { pub struct TappableStream {
tx: tokio::sync::broadcast::Sender<bytes::Bytes>, tx: tokio::sync::broadcast::Sender<WatchEvent>,
} }
#[allow(dead_code)] #[allow(dead_code)]
impl TappableStream { impl TappableStream {
async fn new<T>(stream: T, max_size: usize) -> Self async fn new<T>(stream: T, max_size: usize) -> Self
where where
T: futures::Stream<Item = bytes::Bytes> + Send + 'static, T: futures::Stream<Item = WatchEvent> + Send + 'static,
{ {
let (tx, _) = tokio::sync::broadcast::channel(max_size); let (tx, _) = tokio::sync::broadcast::channel(max_size);
let tx2 = tx.clone(); let tx2 = tx.clone();
...@@ -373,7 +389,7 @@ mod tests { ...@@ -373,7 +389,7 @@ mod tests {
TappableStream { tx } TappableStream { tx }
} }
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> { fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
self.tx.subscribe() self.tx.subscribe()
} }
} }
...@@ -393,6 +409,15 @@ mod tests { ...@@ -393,6 +409,15 @@ mod tests {
let res = bucket.insert(&"test1".into(), "value1", 0).await?; let res = bucket.insert(&"test1".into(), "value1", 0).await?;
assert_eq!(res, StoreOutcome::Created(0)); assert_eq!(res, StoreOutcome::Created(0));
let mut expected = Vec::with_capacity(3);
for i in 1..=3 {
let item = WatchEvent::Put(KeyValue::new(
format!("test{i}"),
bytes::Bytes::from(format!("value{i}").into_bytes()),
));
expected.push(item);
}
let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel(); let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
let ingress = tokio::spawn(async move { let ingress = tokio::spawn(async move {
let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?; let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
...@@ -400,15 +425,16 @@ mod tests { ...@@ -400,15 +425,16 @@ mod tests {
// Put in before starting the watch-all // Put in before starting the watch-all
let v = stream.next().await.unwrap(); let v = stream.next().await.unwrap();
assert_eq!(v, "value1".as_bytes()); assert_eq!(v, expected[0]);
got_first_tx.send(()).unwrap(); got_first_tx.send(()).unwrap();
// Put in after // Put in after
let v = stream.next().await.unwrap(); let v = stream.next().await.unwrap();
assert_eq!(v, "value2".as_bytes()); assert_eq!(v, expected[1]);
let v = stream.next().await.unwrap(); let v = stream.next().await.unwrap();
assert_eq!(v, "value3".as_bytes()); assert_eq!(v, expected[2]);
Ok::<_, StoreError>(()) Ok::<_, StoreError>(())
}); });
...@@ -455,13 +481,18 @@ mod tests { ...@@ -455,13 +481,18 @@ mod tests {
let mut rx1 = tap.subscribe(); let mut rx1 = tap.subscribe();
let mut rx2 = tap.subscribe(); let mut rx2 = tap.subscribe();
let item = WatchEvent::Put(KeyValue::new(
"test1".to_string(),
bytes::Bytes::from(b"GK".as_slice()),
));
let item_clone = item.clone();
let handle1 = tokio::spawn(async move { let handle1 = tokio::spawn(async move {
let b = rx1.recv().await.unwrap(); let b = rx1.recv().await.unwrap();
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K'])); assert_eq!(b, item_clone);
}); });
let handle2 = tokio::spawn(async move { let handle2 = tokio::spawn(async move {
let b = rx2.recv().await.unwrap(); let b = rx2.recv().await.unwrap();
assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K'])); assert_eq!(b, item);
}); });
bucket.insert(&"test1".into(), "GK", 1).await?; bucket.insert(&"test1".into(), "GK", 1).await?;
......
...@@ -5,7 +5,10 @@ use std::collections::HashMap; ...@@ -5,7 +5,10 @@ use std::collections::HashMap;
use std::pin::Pin; use std::pin::Pin;
use std::time::Duration; use std::time::Duration;
use crate::{storage::key_value_store::Key, transports::etcd::Client}; use crate::{
storage::key_value_store::{Key, KeyValue, WatchEvent},
transports::etcd::Client,
};
use async_stream::stream; use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions}; use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
...@@ -104,24 +107,39 @@ impl KeyValueBucket for EtcdBucket { ...@@ -104,24 +107,39 @@ impl KeyValueBucket for EtcdBucket {
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError> ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
{ let prefix = make_key(&self.bucket_name, &"".into());
let k = make_key(&self.bucket_name, &"".into()); tracing::trace!("etcd watch: {prefix}");
tracing::trace!("etcd watch: {k}");
let (watcher, mut watch_stream) = self let (watcher, mut watch_stream) = self
.client .client
.etcd_client() .etcd_client()
.clone() .clone()
.watch(k.as_bytes(), Some(WatchOptions::new().with_prefix())) .watch(prefix.as_bytes(), Some(WatchOptions::new().with_prefix()))
.await .await
.map_err(|e| StoreError::EtcdError(e.to_string()))?; .map_err(|e| StoreError::EtcdError(e.to_string()))?;
let output = stream! { let output = stream! {
let _watcher = watcher; // Keep it alive. Not sure if necessary. let _watcher = watcher; // Keep it alive. Not sure if necessary.
while let Ok(Some(resp)) = watch_stream.message().await { while let Ok(Some(resp)) = watch_stream.message().await {
for e in resp.events() { for e in resp.events() {
if matches!(e.event_type(), EventType::Put) && e.kv().is_some() { let Some(kv) = e.kv() else {
let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into(); continue;
yield b; };
let (k_bytes, v_bytes) = kv.clone().into_key_value();
let key = match String::from_utf8(k_bytes) {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
}
};
let item = KeyValue::new(key, v_bytes.into());
match e.event_type() {
EventType::Put => {
yield WatchEvent::Put(item);
}
EventType::Delete => {
yield WatchEvent::Delete(item);
}
} }
} }
} }
......
...@@ -9,13 +9,18 @@ use std::time::Duration; ...@@ -9,13 +9,18 @@ use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use rand::Rng as _; use rand::Rng as _;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::storage::key_value_store::Key; use crate::storage::key_value_store::{Key, KeyValue, WatchEvent};
use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome}; use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
/// Internal change notification for the in-memory store.
///
/// Sent over the store's `change_sender` channel and translated into a
/// public `WatchEvent` by the bucket's `watch` stream.
#[derive(Clone, Debug)]
enum MemoryEvent {
// Sent by `insert` when it creates an entry for a key that was vacant.
Put { key: String, value: String },
// Sent by `delete` only when the key was actually present and removed.
// Carries no value; the watcher emits the delete with empty bytes.
Delete { key: String },
}
#[derive(Clone)] #[derive(Clone)]
pub struct MemoryStore { pub struct MemoryStore {
inner: Arc<MemoryStoreInner>, inner: Arc<MemoryStoreInner>,
...@@ -29,9 +34,9 @@ impl Default for MemoryStore { ...@@ -29,9 +34,9 @@ impl Default for MemoryStore {
} }
struct MemoryStoreInner { struct MemoryStoreInner {
data: Mutex<HashMap<String, MemoryBucket>>, data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
change_sender: UnboundedSender<(String, String)>, change_sender: UnboundedSender<MemoryEvent>,
change_receiver: Mutex<UnboundedReceiver<(String, String)>>, change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
} }
pub struct MemoryBucketRef { pub struct MemoryBucketRef {
...@@ -56,9 +61,9 @@ impl MemoryStore { ...@@ -56,9 +61,9 @@ impl MemoryStore {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
MemoryStore { MemoryStore {
inner: Arc::new(MemoryStoreInner { inner: Arc::new(MemoryStoreInner {
data: Mutex::new(HashMap::new()), data: parking_lot::Mutex::new(HashMap::new()),
change_sender: tx, change_sender: tx,
change_receiver: Mutex::new(rx), change_receiver: tokio::sync::Mutex::new(rx),
}), }),
connection_id: rand::rng().random(), connection_id: rand::rng().random(),
} }
...@@ -75,7 +80,7 @@ impl KeyValueStore for MemoryStore { ...@@ -75,7 +80,7 @@ impl KeyValueStore for MemoryStore {
// MemoryStore doesn't respect TTL yet // MemoryStore doesn't respect TTL yet
_ttl: Option<Duration>, _ttl: Option<Duration>,
) -> Result<Self::Bucket, StoreError> { ) -> Result<Self::Bucket, StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock();
// Ensure the bucket exists // Ensure the bucket exists
locked_data locked_data
.entry(bucket_name.to_string()) .entry(bucket_name.to_string())
...@@ -89,7 +94,7 @@ impl KeyValueStore for MemoryStore { ...@@ -89,7 +94,7 @@ impl KeyValueStore for MemoryStore {
/// This operation cannot fail on MemoryStore. Always returns Ok. /// This operation cannot fail on MemoryStore. Always returns Ok.
async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> { async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock();
match locked_data.get(bucket_name) { match locked_data.get(bucket_name) {
Some(_) => Ok(Some(MemoryBucketRef { Some(_) => Ok(Some(MemoryBucketRef {
name: bucket_name.to_string(), name: bucket_name.to_string(),
...@@ -112,7 +117,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -112,7 +117,7 @@ impl KeyValueBucket for MemoryBucketRef {
value: &str, value: &str,
revision: u64, revision: u64,
) -> Result<StoreOutcome, StoreError> { ) -> Result<StoreOutcome, StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock();
let mut b = locked_data.get_mut(&self.name); let mut b = locked_data.get_mut(&self.name);
let Some(bucket) = b.as_mut() else { let Some(bucket) = b.as_mut() else {
return Err(StoreError::MissingBucket(self.name.to_string())); return Err(StoreError::MissingBucket(self.name.to_string()));
...@@ -120,10 +125,10 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -120,10 +125,10 @@ impl KeyValueBucket for MemoryBucketRef {
let outcome = match bucket.data.entry(key.to_string()) { let outcome = match bucket.data.entry(key.to_string()) {
Entry::Vacant(e) => { Entry::Vacant(e) => {
e.insert((revision, value.to_string())); e.insert((revision, value.to_string()));
let _ = self let _ = self.inner.change_sender.send(MemoryEvent::Put {
.inner key: key.to_string(),
.change_sender value: value.to_string(),
.send((key.to_string(), value.to_string())); });
StoreOutcome::Created(revision) StoreOutcome::Created(revision)
} }
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
...@@ -140,7 +145,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -140,7 +145,7 @@ impl KeyValueBucket for MemoryBucketRef {
} }
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> { async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get(&self.name) else { let Some(bucket) = locked_data.get(&self.name) else {
return Ok(None); return Ok(None);
}; };
...@@ -151,11 +156,15 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -151,11 +156,15 @@ impl KeyValueBucket for MemoryBucketRef {
} }
async fn delete(&self, key: &Key) -> Result<(), StoreError> { async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let mut locked_data = self.inner.data.lock().await; let mut locked_data = self.inner.data.lock();
let Some(bucket) = locked_data.get_mut(&self.name) else { let Some(bucket) = locked_data.get_mut(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string())); return Err(StoreError::MissingBucket(self.name.to_string()));
}; };
bucket.data.remove(&key.0); if bucket.data.remove(&key.0).is_some() {
let _ = self.inner.change_sender.send(MemoryEvent::Delete {
key: key.to_string(),
});
}
Ok(()) Ok(())
} }
...@@ -164,21 +173,25 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -164,21 +173,25 @@ impl KeyValueBucket for MemoryBucketRef {
/// Caller takes the lock so only a single caller may use this at once. /// Caller takes the lock so only a single caller may use this at once.
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError> ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
{ // All the existing ones first
let mut existing_items = vec![];
let mut seen_keys = HashSet::new();
let data_lock = self.inner.data.lock();
let Some(bucket) = data_lock.get(&self.name) else {
return Err(StoreError::MissingBucket(self.name.to_string()));
};
for (key, (_rev, v)) in &bucket.data {
seen_keys.insert(key.clone());
let item = KeyValue::new(key.clone(), bytes::Bytes::from(v.clone().into_bytes()));
existing_items.push(WatchEvent::Put(item));
}
drop(data_lock);
Ok(Box::pin(async_stream::stream! { Ok(Box::pin(async_stream::stream! {
// All the existing ones first for event in existing_items {
let mut seen = HashSet::new(); yield event;
let data_lock = self.inner.data.lock().await;
let Some(bucket) = data_lock.get(&self.name) else {
tracing::error!(bucket_name = self.name, "watch: Missing bucket");
return;
};
for (_rev, v) in bucket.data.values() {
seen.insert(v.clone());
yield bytes::Bytes::from(v.clone());
} }
drop(data_lock);
// Now any new ones // Now any new ones
let mut rcv_lock = self.inner.change_receiver.lock().await; let mut rcv_lock = self.inner.change_receiver.lock().await;
loop { loop {
...@@ -187,11 +200,16 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -187,11 +200,16 @@ impl KeyValueBucket for MemoryBucketRef {
// Channel is closed, no more values coming // Channel is closed, no more values coming
break; break;
}, },
Some((_k, v)) => { Some(MemoryEvent::Put { key, value }) => {
if seen.contains(&v) { if seen_keys.contains(&key) {
continue; continue;
} }
yield bytes::Bytes::from(v.clone()); let item = KeyValue::new(key, bytes::Bytes::from(value));
yield WatchEvent::Put(item);
},
Some(MemoryEvent::Delete { key }) => {
let item = KeyValue::new(key, bytes::Bytes::new());
yield WatchEvent::Delete(item);
} }
} }
} }
...@@ -199,7 +217,7 @@ impl KeyValueBucket for MemoryBucketRef { ...@@ -199,7 +217,7 @@ impl KeyValueBucket for MemoryBucketRef {
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock().await; let locked_data = self.inner.data.lock();
match locked_data.get(&self.name) { match locked_data.get(&self.name) {
Some(bucket) => Ok(bucket Some(bucket) => Ok(bucket
.data .data
......
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
use std::{collections::HashMap, pin::Pin, time::Duration}; use std::{collections::HashMap, pin::Pin, time::Duration};
use crate::{ use crate::{
protocols::EndpointId, slug::Slug, storage::key_value_store::Key, transports::nats::Client, protocols::EndpointId,
slug::Slug,
storage::key_value_store::{Key, KeyValue, WatchEvent},
transports::nats::Client,
}; };
use async_nats::jetstream::kv::Operation;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
...@@ -141,8 +145,7 @@ impl KeyValueBucket for NATSBucket { ...@@ -141,8 +145,7 @@ impl KeyValueBucket for NATSBucket {
async fn watch( async fn watch(
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError> ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
{
let watch_stream = self let watch_stream = self
.nats_store .nats_store
.watch_all() .watch_all()
...@@ -156,7 +159,15 @@ impl KeyValueBucket for NATSBucket { ...@@ -156,7 +159,15 @@ impl KeyValueBucket for NATSBucket {
async_nats::error::Error<_>, async_nats::error::Error<_>,
>| async move { >| async move {
match maybe_entry { match maybe_entry {
Ok(entry) => Some(entry.value), Ok(entry) => {
let item = KeyValue::new(entry.key, entry.value);
Some(match entry.operation {
Operation::Put => WatchEvent::Put(item),
Operation::Delete => WatchEvent::Delete(item),
// TODO: What is Purge? Not urgent, NATS impl not used
Operation::Purge => WatchEvent::Delete(item),
})
}
Err(e) => { Err(e) => {
tracing::error!(error=%e, "watch fatal err"); tracing::error!(error=%e, "watch fatal err");
None None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment