// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{collections::HashMap, pin::Pin, time::Duration};

6
use crate::{protocols::EndpointId, slug::Slug, storage::kv, transports::nats::Client};
7
use async_nats::jetstream::kv::Operation;
8
9
10
use async_trait::async_trait;
use futures::StreamExt;

11
use super::{Bucket, Store, StoreError, StoreOutcome};
12
13

#[derive(Clone)]
14
pub struct NATSStore {
15
    client: Client,
16
    endpoint: EndpointId,
17
18
19
20
21
22
23
}

/// One JetStream key-value bucket, handed out by [`NATSStore`].
pub struct NATSBucket {
    /// Handle to the underlying JetStream key-value store.
    nats_store: async_nats::jetstream::kv::Store,
}

#[async_trait]
24
impl Store for NATSStore {
25
26
    type Bucket = NATSBucket;

27
28
29
30
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        ttl: Option<Duration>,
31
    ) -> Result<Self::Bucket, StoreError> {
32
33
34
35
        let name = Slug::slugify(bucket_name);
        let nats_store = self
            .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
            .await?;
36
        Ok(NATSBucket { nats_store })
37
38
    }

39
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
40
41
        let name = Slug::slugify(bucket_name);
        match self.get_key_value(&self.endpoint.namespace, &name).await? {
42
            Some(nats_store) => Ok(Some(NATSBucket { nats_store })),
43
44
45
            None => Ok(None),
        }
    }
46
47
48
49

    fn connection_id(&self) -> u64 {
        self.client.client().server_info().client_id
    }
50
51
52
53
54

    fn shutdown(&self) {
        // TODO: Track and delete any owned keys
        // The TTL should ensure NATS does it, but best we do it immediately
    }
55
56
}

57
impl NATSStore {
58
    pub fn new(client: Client, endpoint: EndpointId) -> Self {
59
        NATSStore { client, endpoint }
60
61
62
63
64
65
66
67
68
69
70
71
    }

    /// Get or create a key-value store (aka bucket) in NATS.
    ///
    /// ttl is only used if we are creating the bucket, so if that has
    /// changed first delete the bucket.
    async fn get_or_create_key_value(
        &self,
        namespace: &str,
        bucket_name: &Slug,
        // Delete entries older than this
        ttl: Option<Duration>,
72
    ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
        if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
            return Ok(kv);
        }

        // It doesn't exist, create it

        let bucket_name = single_name(namespace, bucket_name);
        let js = self.client.jetstream();
        let create_result = js
            .create_key_value(
                // TODO: configure the bucket, probably need to pass some of these values in
                async_nats::jetstream::kv::Config {
                    bucket: bucket_name.clone(),
                    max_age: ttl.unwrap_or_default(),
                    ..Default::default()
                },
            )
            .await;
91
        let nats_store = create_result
92
            .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
93
        tracing::debug!("Created bucket {bucket_name}");
94
        Ok(nats_store)
95
96
97
98
99
100
    }

    async fn get_key_value(
        &self,
        namespace: &str,
        bucket_name: &Slug,
101
    ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
102
103
104
105
106
107
108
109
110
111
        let bucket_name = single_name(namespace, bucket_name);
        let js = self.client.jetstream();

        use async_nats::jetstream::context::KeyValueErrorKind;
        match js.get_key_value(&bucket_name).await {
            Ok(store) => Ok(Some(store)),
            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
                // bucket doesn't exist
                Ok(None)
            }
112
            Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
113
114
115
116
117
        }
    }
}

#[async_trait]
118
impl Bucket for NATSBucket {
119
120
    async fn insert(
        &self,
121
        key: &kv::Key,
122
        value: bytes::Bytes,
123
        revision: u64,
124
    ) -> Result<StoreOutcome, StoreError> {
125
126
127
128
129
130
131
        if revision == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, revision).await
        }
    }

132
    async fn get(&self, key: &kv::Key) -> Result<Option<bytes::Bytes>, StoreError> {
133
134
135
        self.nats_store
            .get(key)
            .await
136
            .map_err(|e| StoreError::NATSError(e.to_string()))
137
138
    }

139
    async fn delete(&self, key: &kv::Key) -> Result<(), StoreError> {
140
141
142
        self.nats_store
            .delete(key)
            .await
143
            .map_err(|e| StoreError::NATSError(e.to_string()))
144
145
146
147
    }

    async fn watch(
        &self,
148
149
    ) -> Result<Pin<Box<dyn futures::Stream<Item = kv::WatchEvent> + Send + 'life0>>, StoreError>
    {
150
151
152
153
        let watch_stream = self
            .nats_store
            .watch_all()
            .await
154
            .map_err(|e| StoreError::NATSError(e.to_string()))?;
155
156
157
158
159
160
161
162
        // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
        Ok(Box::pin(
            watch_stream.filter_map(
                |maybe_entry: Result<
                    async_nats::jetstream::kv::Entry,
                    async_nats::error::Error<_>,
                >| async move {
                    match maybe_entry {
163
                        Ok(entry) => {
164
                            let key = kv::Key::new(entry.key);
165
                            Some(match entry.operation {
166
                                Operation::Put => {
167
168
                                    let item = kv::KeyValue::new(key, entry.value);
                                    kv::WatchEvent::Put(item)
169
                                }
170
                                Operation::Delete => kv::WatchEvent::Delete(key),
171
                                // TODO: What is Purge? Not urgent, NATS impl not used
172
                                Operation::Purge => kv::WatchEvent::Delete(key),
173
174
                            })
                        }
175
176
177
178
179
180
181
182
183
184
                        Err(e) => {
                            tracing::error!(error=%e, "watch fatal err");
                            None
                        }
                    }
                },
            ),
        ))
    }

185
    async fn entries(&self) -> Result<HashMap<kv::Key, bytes::Bytes>, StoreError> {
186
187
188
189
        let mut key_stream = self
            .nats_store
            .keys()
            .await
190
            .map_err(|e| StoreError::NATSError(e.to_string()))?;
191
192
193
        let mut out = HashMap::new();
        while let Some(Ok(key)) = key_stream.next().await {
            if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
194
                out.insert(kv::Key::new(key), entry.value);
195
196
197
198
199
200
201
            }
        }
        Ok(out)
    }
}

impl NATSBucket {
202
    async fn create(&self, key: &kv::Key, value: bytes::Bytes) -> Result<StoreOutcome, StoreError> {
203
        match self.nats_store.create(&key, value).await {
204
            Ok(revision) => Ok(StoreOutcome::Created(revision)),
205
206
            Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
                // key exists, get the revsion
207
                match self.nats_store.entry(key).await {
208
                    Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
209
210
                    Ok(None) => {
                        tracing::error!(
211
                            %key,
212
213
                            "Race condition, key deleted between create and fetch. Retry."
                        );
214
                        Err(StoreError::Retry)
215
                    }
216
                    Err(err) => Err(StoreError::NATSError(err.to_string())),
217
218
                }
            }
219
            Err(err) => Err(StoreError::NATSError(err.to_string())),
220
221
222
223
224
        }
    }

    async fn update(
        &self,
225
        key: &kv::Key,
226
        value: bytes::Bytes,
227
        revision: u64,
228
    ) -> Result<StoreOutcome, StoreError> {
229
        match self.nats_store.update(key, value.clone(), revision).await {
230
            Ok(revision) => Ok(StoreOutcome::Created(revision)),
231
232
233
            Err(err)
                if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
            {
234
                tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
235
236
                self.resync_update(key, value).await
            }
237
            Err(err) => Err(StoreError::NATSError(err.to_string())),
238
239
240
241
242
        }
    }

    /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
    /// and try the update again.
243
244
    async fn resync_update(
        &self,
245
        key: &kv::Key,
246
247
        value: bytes::Bytes,
    ) -> Result<StoreOutcome, StoreError> {
248
        match self.nats_store.entry(key).await {
249
250
251
            Ok(Some(entry)) => {
                // Re-try the update with new version number
                let next_rev = entry.revision + 1;
252
                match self.nats_store.update(key, value, next_rev).await {
253
254
                    Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
                    Err(err) => Err(StoreError::NATSError(format!(
255
256
257
258
259
                        "Error during update of key {key} after resync: {err}"
                    ))),
                }
            }
            Ok(None) => {
260
                tracing::warn!(%key, "Entry does not exist during resync, creating.");
261
262
263
                self.create(key, value).await
            }
            Err(err) => {
264
                tracing::error!(%key, %err, "Failed fetching entry during resync");
265
                Err(StoreError::NATSError(err.to_string()))
266
267
268
269
270
271
272
273
274
275
            }
        }
    }
}

/// Join `namespace` and `name` into one `<namespace>_<name>` bucket name.
///
/// async-nats rejects multi-part (dotted) subjects when creating KV buckets (and probably in
/// many other places), so the two parts are joined with an underscore instead.
fn single_name(namespace: &str, name: &Slug) -> String {
    let mut joined = String::from(namespace);
    joined.push('_');
    joined.push_str(&name.to_string());
    joined
}