mem.rs 8.05 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
// SPDX-License-Identifier: Apache-2.0

use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
11
use rand::Rng as _;
12
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13

14
use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
15

16
17
#[derive(Clone, Debug)]
enum MemoryEvent {
18
    Put { key: String, value: bytes::Bytes },
19
20
21
    Delete { key: String },
}

22
#[derive(Clone)]
23
24
25
pub struct MemoryStore {
    inner: Arc<MemoryStoreInner>,
    connection_id: u64,
26
27
}

28
impl Default for MemoryStore {
29
30
31
32
33
    fn default() -> Self {
        Self::new()
    }
}

34
struct MemoryStoreInner {
35
36
37
    data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
    change_sender: UnboundedSender<MemoryEvent>,
    change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
38
39
40
41
}

pub struct MemoryBucketRef {
    name: String,
42
    inner: Arc<MemoryStoreInner>,
43
44
45
}

struct MemoryBucket {
46
    data: HashMap<String, (u64, bytes::Bytes)>,
47
48
49
50
51
52
53
54
55
56
}

impl MemoryBucket {
    fn new() -> Self {
        MemoryBucket {
            data: HashMap::new(),
        }
    }
}

57
impl MemoryStore {
58
    pub(super) fn new() -> Self {
59
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
60
61
        MemoryStore {
            inner: Arc::new(MemoryStoreInner {
62
                data: parking_lot::Mutex::new(HashMap::new()),
63
                change_sender: tx,
64
                change_receiver: tokio::sync::Mutex::new(rx),
65
            }),
66
            connection_id: rand::rng().random(),
67
68
69
70
71
        }
    }
}

#[async_trait]
72
impl Store for MemoryStore {
73
74
    type Bucket = MemoryBucketRef;

75
76
77
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
78
        // MemoryStore doesn't respect TTL yet
79
        _ttl: Option<Duration>,
80
    ) -> Result<Self::Bucket, StoreError> {
81
        let mut locked_data = self.inner.data.lock();
82
83
84
85
86
        // Ensure the bucket exists
        locked_data
            .entry(bucket_name.to_string())
            .or_insert_with(MemoryBucket::new);
        // Return an object able to access it
87
        Ok(MemoryBucketRef {
88
89
            name: bucket_name.to_string(),
            inner: self.inner.clone(),
90
        })
91
92
    }

93
    /// This operation cannot fail on MemoryStore. Always returns Ok.
94
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
95
        let locked_data = self.inner.data.lock();
96
        match locked_data.get(bucket_name) {
97
            Some(_) => Ok(Some(MemoryBucketRef {
98
99
                name: bucket_name.to_string(),
                inner: self.inner.clone(),
100
            })),
101
102
103
            None => Ok(None),
        }
    }
104
105
106
107

    fn connection_id(&self) -> u64 {
        self.connection_id
    }
108
109

    fn shutdown(&self) {}
110
111
112
}

#[async_trait]
113
impl Bucket for MemoryBucketRef {
114
115
    async fn insert(
        &self,
116
        key: &Key,
117
        value: bytes::Bytes,
118
        revision: u64,
119
    ) -> Result<StoreOutcome, StoreError> {
120
        let mut locked_data = self.inner.data.lock();
121
122
        let mut b = locked_data.get_mut(&self.name);
        let Some(bucket) = b.as_mut() else {
123
            return Err(StoreError::MissingBucket(self.name.to_string()));
124
125
126
        };
        let outcome = match bucket.data.entry(key.to_string()) {
            Entry::Vacant(e) => {
127
                e.insert((revision, value.clone()));
128
129
                let _ = self.inner.change_sender.send(MemoryEvent::Put {
                    key: key.to_string(),
130
                    value,
131
                });
132
                StoreOutcome::Created(revision)
133
134
135
136
            }
            Entry::Occupied(mut entry) => {
                let (rev, _v) = entry.get();
                if *rev == revision {
137
                    StoreOutcome::Exists(revision)
138
                } else {
139
                    entry.insert((revision, value));
140
                    StoreOutcome::Created(revision)
141
142
143
144
145
146
                }
            }
        };
        Ok(outcome)
    }

147
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
148
        let locked_data = self.inner.data.lock();
149
150
151
        let Some(bucket) = locked_data.get(&self.name) else {
            return Ok(None);
        };
152
        Ok(bucket.data.get(&key.0).map(|(_, v)| v.clone()))
153
154
    }

155
    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
156
        let mut locked_data = self.inner.data.lock();
157
        let Some(bucket) = locked_data.get_mut(&self.name) else {
158
            return Err(StoreError::MissingBucket(self.name.to_string()));
159
        };
160
161
162
163
164
        if bucket.data.remove(&key.0).is_some() {
            let _ = self.inner.change_sender.send(MemoryEvent::Delete {
                key: key.to_string(),
            });
        }
165
166
167
168
169
170
171
172
        Ok(())
    }

    /// All current values in the bucket first, then block waiting for new
    /// values to be published.
    /// Caller takes the lock so only a single caller may use this at once.
    async fn watch(
        &self,
173
174
175
176
177
178
179
180
181
182
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        // All the existing ones first
        let mut existing_items = vec![];
        let mut seen_keys = HashSet::new();
        let data_lock = self.inner.data.lock();
        let Some(bucket) = data_lock.get(&self.name) else {
            return Err(StoreError::MissingBucket(self.name.to_string()));
        };
        for (key, (_rev, v)) in &bucket.data {
            seen_keys.insert(key.clone());
183
            let item = KeyValue::new(Key::new(key.clone()), v.clone());
184
185
186
187
            existing_items.push(WatchEvent::Put(item));
        }
        drop(data_lock);

188
        Ok(Box::pin(async_stream::stream! {
189
190
            for event in existing_items {
                yield event;
191
192
193
194
195
196
197
198
199
            }
            // Now any new ones
            let mut rcv_lock = self.inner.change_receiver.lock().await;
            loop {
                match rcv_lock.recv().await {
                    None => {
                        // Channel is closed, no more values coming
                        break;
                    },
200
201
                    Some(MemoryEvent::Put { key, value }) => {
                        if seen_keys.contains(&key) {
202
203
                            continue;
                        }
204
                        let item = KeyValue::new(Key::new(key), value);
205
206
207
                        yield WatchEvent::Put(item);
                    },
                    Some(MemoryEvent::Delete { key }) => {
208
                        yield WatchEvent::Delete(Key::new(key));
209
210
211
212
213
214
                    }
                }
            }
        }))
    }

215
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
216
        let locked_data = self.inner.data.lock();
217
        match locked_data.get(&self.name) {
218
219
220
221
222
223
224
225
226
            Some(bucket) => {
                let mut out = HashMap::new();
                for (k, (_rev, v)) in bucket.data.iter() {
                    let key = Key::new([self.name.clone(), k.to_string()].join("/"));
                    let value = v.clone();
                    out.insert(key, value);
                }
                Ok(out)
            }
227
            None => Err(StoreError::MissingBucket(self.name.clone())),
228
229
230
        }
    }
}
231
232
233

#[cfg(test)]
mod tests {
    use crate::storage::kv::{Bucket as _, Key, MemoryStore, Store as _};

    /// `entries()` must return exactly the inserted entries, each keyed by
    /// `"<bucket>/<key>"` and carrying the stored value.
    #[tokio::test]
    async fn test_entries_full_path() {
        let m = MemoryStore::new();
        let bucket = m.get_or_create_bucket("bucket1", None).await.unwrap();
        for (key, value) in [("key1", "value1"), ("key2", "value2")] {
            let _ = bucket
                .insert(&Key::new(key.to_string()), value.into(), 0)
                .await
                .unwrap();
        }

        let entries = bucket.entries().await.unwrap();
        let expected = [("bucket1/key1", "value1"), ("bucket1/key2", "value2")];
        // Exact count: extra keys (e.g. un-prefixed duplicates) must not slip through.
        assert_eq!(entries.len(), expected.len());
        for (full_key, value) in expected {
            let got = entries.get(&Key::new(full_key.to_string()));
            assert_eq!(
                got.map(|v| &v[..]),
                Some(value.as_bytes()),
                "wrong entry for {full_key}"
            );
        }
    }
}