// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use rand::Rng as _; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent}; #[derive(Clone, Debug)] enum MemoryEvent { Put { key: String, value: bytes::Bytes }, Delete { key: String }, } #[derive(Clone)] pub struct MemoryStore { inner: Arc, connection_id: u64, } impl Default for MemoryStore { fn default() -> Self { Self::new() } } struct MemoryStoreInner { data: parking_lot::Mutex>, change_sender: UnboundedSender, change_receiver: tokio::sync::Mutex>, } pub struct MemoryBucketRef { name: String, inner: Arc, } struct MemoryBucket { data: HashMap, } impl MemoryBucket { fn new() -> Self { MemoryBucket { data: HashMap::new(), } } } impl MemoryStore { pub(super) fn new() -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); MemoryStore { inner: Arc::new(MemoryStoreInner { data: parking_lot::Mutex::new(HashMap::new()), change_sender: tx, change_receiver: tokio::sync::Mutex::new(rx), }), connection_id: rand::rng().random(), } } } #[async_trait] impl Store for MemoryStore { type Bucket = MemoryBucketRef; async fn get_or_create_bucket( &self, bucket_name: &str, // MemoryStore doesn't respect TTL yet _ttl: Option, ) -> Result { let mut locked_data = self.inner.data.lock(); // Ensure the bucket exists locked_data .entry(bucket_name.to_string()) .or_insert_with(MemoryBucket::new); // Return an object able to access it Ok(MemoryBucketRef { name: bucket_name.to_string(), inner: self.inner.clone(), }) } /// This operation cannot fail on MemoryStore. Always returns Ok. async fn get_bucket(&self, bucket_name: &str) -> Result, StoreError> { let locked_data = self.inner.data.lock(); match locked_data.get(bucket_name) { Some(_) => Ok(Some(MemoryBucketRef { name: bucket_name.to_string(), inner: self.inner.clone(), })), None => Ok(None), } } fn connection_id(&self) -> u64 { self.connection_id } fn shutdown(&self) {} } #[async_trait] impl Bucket for MemoryBucketRef { async fn insert( &self, key: &Key, value: bytes::Bytes, revision: u64, ) -> Result { let mut locked_data = self.inner.data.lock(); let mut b = locked_data.get_mut(&self.name); let Some(bucket) = b.as_mut() else { return Err(StoreError::MissingBucket(self.name.to_string())); }; let outcome = match bucket.data.entry(key.to_string()) { Entry::Vacant(e) => { e.insert((revision, value.clone())); let _ = self.inner.change_sender.send(MemoryEvent::Put { key: key.to_string(), value, }); StoreOutcome::Created(revision) } Entry::Occupied(mut entry) => { let (rev, _v) = entry.get(); if *rev == revision { StoreOutcome::Exists(revision) } else { entry.insert((revision, value)); StoreOutcome::Created(revision) } } }; Ok(outcome) } async fn get(&self, key: &Key) -> Result, StoreError> { let locked_data = self.inner.data.lock(); let Some(bucket) = locked_data.get(&self.name) else { return Ok(None); }; Ok(bucket.data.get(&key.0).map(|(_, v)| v.clone())) } async fn delete(&self, key: &Key) -> Result<(), StoreError> { let mut locked_data = self.inner.data.lock(); let Some(bucket) = locked_data.get_mut(&self.name) else { return Err(StoreError::MissingBucket(self.name.to_string())); }; if bucket.data.remove(&key.0).is_some() { let _ = self.inner.change_sender.send(MemoryEvent::Delete { key: key.to_string(), }); } Ok(()) } /// All current values in the bucket first, then block waiting for new /// values to be published. /// Caller takes the lock so only a single caller may use this at once. async fn watch( &self, ) -> Result + Send + 'life0>>, StoreError> { // All the existing ones first let mut existing_items = vec![]; let mut seen_keys = HashSet::new(); let data_lock = self.inner.data.lock(); let Some(bucket) = data_lock.get(&self.name) else { return Err(StoreError::MissingBucket(self.name.to_string())); }; for (key, (_rev, v)) in &bucket.data { seen_keys.insert(key.clone()); let item = KeyValue::new(Key::new(key.clone()), v.clone()); existing_items.push(WatchEvent::Put(item)); } drop(data_lock); Ok(Box::pin(async_stream::stream! { for event in existing_items { yield event; } // Now any new ones let mut rcv_lock = self.inner.change_receiver.lock().await; loop { match rcv_lock.recv().await { None => { // Channel is closed, no more values coming break; }, Some(MemoryEvent::Put { key, value }) => { if seen_keys.contains(&key) { continue; } let item = KeyValue::new(Key::new(key), value); yield WatchEvent::Put(item); }, Some(MemoryEvent::Delete { key }) => { yield WatchEvent::Delete(Key::new(key)); } } } })) } async fn entries(&self) -> Result, StoreError> { let locked_data = self.inner.data.lock(); match locked_data.get(&self.name) { Some(bucket) => { let mut out = HashMap::new(); for (k, (_rev, v)) in bucket.data.iter() { let key = Key::new([self.name.clone(), k.to_string()].join("/")); let value = v.clone(); out.insert(key, value); } Ok(out) } None => Err(StoreError::MissingBucket(self.name.clone())), } } } #[cfg(test)] mod tests { use crate::storage::kv::{Bucket as _, Key, MemoryStore, Store as _}; use std::collections::HashSet; #[tokio::test] async fn test_entries_full_path() { let m = MemoryStore::new(); let bucket = m.get_or_create_bucket("bucket1", None).await.unwrap(); let _ = bucket .insert(&Key::new("key1".to_string()), "value1".into(), 0) .await .unwrap(); let _ = bucket .insert(&Key::new("key2".to_string()), "value2".into(), 0) .await .unwrap(); let entries = bucket.entries().await.unwrap(); let keys: HashSet = entries.into_keys().collect(); assert!(keys.contains(&Key::new("bucket1/key1".to_string()))); assert!(keys.contains(&Key::new("bucket1/key2".to_string()))); } }