kv.rs 18.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
// SPDX-License-Identifier: Apache-2.0

//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else.

7
use std::borrow::Cow;
8
use std::pin::Pin;
9
use std::str::FromStr;
10
11
use std::sync::Arc;
use std::time::Duration;
12
13
use std::{collections::HashMap, path::PathBuf};
use std::{env, fmt};
14

15
use crate::CancellationToken;
16
use crate::transports::etcd as etcd_transport;
17
18
use async_trait::async_trait;
use futures::StreamExt;
19
use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
20
21
22
use serde::{Deserialize, Serialize};

mod mem;
23
pub use mem::MemoryStore;
24
mod nats;
25
pub use nats::NATSStore;
26
mod etcd;
27
pub use etcd::EtcdStore;
28
29
mod file;
pub use file::FileStore;
30

31
/// How long `Manager::watch` waits for room in the consumer channel before it logs an error
/// and gives up on delivering that one event.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_secs(1);
32

33
34
35
/// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36
37
38
pub struct Key(String);

impl Key {
39
40
    pub fn new(s: String) -> Key {
        Key(s)
41
42
    }

43
44
45
46
47
48
49
50
51
52
    /// Takes a URL-safe percent-encoded string and creates a Key from it by decoding first.
    /// dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f becomes dynamo/backend/generate/17216e63492ef21f
    pub fn from_url_safe(s: &str) -> Key {
        Key(percent_decode_str(s).decode_utf8_lossy().to_string())
    }

    /// A URL-safe percent-encoded representation of this key.
    /// e.g.  dynamo/backend/generate/17216e63492ef21f becomes dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f
    pub fn url_safe(&self) -> Cow<'_, str> {
        percent_encode(self.0.as_bytes(), NON_ALPHANUMERIC).into()
53
54
55
56
57
    }
}

impl From<&str> for Key {
    fn from(s: &str) -> Key {
58
        Key::new(s.to_string())
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    }
}

impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl AsRef<str> for Key {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl From<&Key> for String {
    fn from(k: &Key) -> String {
        k.0.clone()
    }
}

80
81
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
82
    key: Key,
83
84
85
86
    value: bytes::Bytes,
}

impl KeyValue {
87
    pub fn new(key: Key, value: bytes::Bytes) -> Self {
88
89
        KeyValue { key, value }
    }
90
91

    pub fn key(&self) -> String {
92
        self.key.clone().to_string()
93
94
95
    }

    pub fn key_str(&self) -> &str {
96
        self.key.as_ref()
97
98
99
100
101
102
103
104
105
    }

    pub fn value(&self) -> &[u8] {
        &self.value
    }

    pub fn value_str(&self) -> anyhow::Result<&str> {
        std::str::from_utf8(self.value()).map_err(From::from)
    }
106
107
108
109
110
}

#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    Put(KeyValue),
111
    Delete(Key),
112
113
}

114
#[async_trait]
115
116
pub trait Store: Send + Sync {
    type Bucket: Bucket + Send + Sync + 'static;
117

118
119
120
121
122
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
123
124
125
126
127
    ) -> Result<Self::Bucket, StoreError>;

    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    fn connection_id(&self) -> u64;
128
129
130
131
132

    fn shutdown(&self);
}

#[derive(Clone, Debug, Default)]
133
pub enum Selector {
134
135
136
137
138
139
140
141
    // Box it because it is significantly bigger than the other variants
    Etcd(Box<etcd_transport::ClientOptions>),
    File(PathBuf),
    #[default]
    Memory,
    // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
}

142
impl fmt::Display for Selector {
143
144
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
145
            Selector::Etcd(opts) => {
146
147
148
                let urls = opts.etcd_url.join(",");
                write!(f, "Etcd({urls})")
            }
149
150
            Selector::File(path) => write!(f, "File({})", path.display()),
            Selector::Memory => write!(f, "Memory"),
151
152
153
154
        }
    }
}

155
impl FromStr for Selector {
156
157
    type Err = anyhow::Error;

158
    fn from_str(s: &str) -> anyhow::Result<Selector> {
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        match s {
            "etcd" => Ok(Self::Etcd(Box::default())),
            "file" => {
                let root = env::var("DYN_FILE_KV")
                    .map(PathBuf::from)
                    .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
                Ok(Self::File(root))
            }
            "mem" => Ok(Self::Memory),
            x => anyhow::bail!("Unknown key-value store type '{x}'"),
        }
    }
}

173
impl TryFrom<String> for Selector {
174
175
    type Error = anyhow::Error;

176
    fn try_from(s: String) -> anyhow::Result<Selector> {
177
178
        s.parse()
    }
179
180
181
}

#[allow(clippy::large_enum_variant)]
182
enum KeyValueStoreEnum {
183
184
185
    Memory(MemoryStore),
    Nats(NATSStore),
    Etcd(EtcdStore),
186
    File(FileStore),
187
188
189
190
191
192
193
194
}

impl KeyValueStoreEnum {
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
195
    ) -> Result<Box<dyn Bucket>, StoreError> {
196
197
198
199
200
        use KeyValueStoreEnum::*;
        Ok(match self {
            Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
            Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
            Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
201
            File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
202
203
        })
    }
204

205
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Box<dyn Bucket>>, StoreError> {
206
        use KeyValueStoreEnum::*;
207
        let maybe_bucket: Option<Box<dyn Bucket>> = match self {
208
209
210
            Memory(x) => x
                .get_bucket(bucket_name)
                .await?
211
                .map(|b| Box::new(b) as Box<dyn Bucket>),
212
213
214
            Nats(x) => x
                .get_bucket(bucket_name)
                .await?
215
                .map(|b| Box::new(b) as Box<dyn Bucket>),
216
217
218
            Etcd(x) => x
                .get_bucket(bucket_name)
                .await?
219
                .map(|b| Box::new(b) as Box<dyn Bucket>),
220
221
222
            File(x) => x
                .get_bucket(bucket_name)
                .await?
223
                .map(|b| Box::new(b) as Box<dyn Bucket>),
224
225
226
        };
        Ok(maybe_bucket)
    }
227

228
229
230
231
232
233
    fn connection_id(&self) -> u64 {
        use KeyValueStoreEnum::*;
        match self {
            Memory(x) => x.connection_id(),
            Etcd(x) => x.connection_id(),
            Nats(x) => x.connection_id(),
234
235
236
237
238
239
240
241
242
243
244
            File(x) => x.connection_id(),
        }
    }

    fn shutdown(&self) {
        use KeyValueStoreEnum::*;
        match self {
            Memory(x) => x.shutdown(),
            Etcd(x) => x.shutdown(),
            Nats(x) => x.shutdown(),
            File(x) => x.shutdown(),
245
246
        }
    }
247
248
}

249
#[derive(Clone)]
250
pub struct Manager(Arc<KeyValueStoreEnum>);
251

252
impl Default for Manager {
253
    fn default() -> Self {
254
        Manager::memory()
255
256
    }
}
257

258
impl Manager {
259
260
261
262
263
264
265
266
267
    /// In-memory KeyValueStoreManager for testing
    pub fn memory() -> Self {
        Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
    }

    pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
        Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
    }

268
269
    pub fn file<P: Into<PathBuf>>(cancel_token: CancellationToken, root: P) -> Self {
        Self::new(KeyValueStoreEnum::File(FileStore::new(cancel_token, root)))
270
271
    }

272
273
    fn new(s: KeyValueStoreEnum) -> Manager {
        Manager(Arc::new(s))
274
275
276
277
278
279
280
    }

    pub async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
281
    ) -> Result<Box<dyn Bucket>, StoreError> {
282
283
284
285
286
287
        self.0.get_or_create_bucket(bucket_name, ttl).await
    }

    pub async fn get_bucket(
        &self,
        bucket_name: &str,
288
    ) -> Result<Option<Box<dyn Bucket>>, StoreError> {
289
290
291
292
293
        self.0.get_bucket(bucket_name).await
    }

    pub fn connection_id(&self) -> u64 {
        self.0.connection_id()
294
295
296
297
298
    }

    pub async fn load<T: for<'a> Deserialize<'a>>(
        &self,
        bucket: &str,
299
        key: &Key,
300
    ) -> Result<Option<T>, StoreError> {
301
302
303
304
        let Some(bucket) = self.0.get_bucket(bucket).await? else {
            // No bucket means no cards
            return Ok(None);
        };
305
306
        Ok(match bucket.get(key).await? {
            Some(card_bytes) => {
307
                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
308
                Some(card)
309
            }
310
311
            None => None,
        })
312
313
314
315
316
    }

    /// Returns a receiver that will receive all the existing keys, and
    /// then block and receive new keys as they are created.
    /// Starts a task that runs forever, watches the store.
317
    pub fn watch(
318
319
320
        self: Arc<Self>,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
321
        cancel_token: CancellationToken,
322
    ) -> (
323
        tokio::task::JoinHandle<Result<(), StoreError>>,
324
        tokio::sync::mpsc::Receiver<WatchEvent>,
325
326
    ) {
        let bucket_name = bucket_name.to_string();
327
328
329
330
        // Use a larger channel capacity to reduce the likelihood that a slow consumer
        // during the initial KV-store replay phase triggers send timeouts. Events may
        // still be dropped if the consumer cannot keep up within `WATCH_SEND_TIMEOUT`.
        let (tx, rx) = tokio::sync::mpsc::channel(16384);
331
332
333
334
335
336
337
338
339
        let watch_task = tokio::spawn(async move {
            // Start listening for changes but don't poll this yet
            let bucket = self
                .0
                .get_or_create_bucket(&bucket_name, bucket_ttl)
                .await?;
            let mut stream = bucket.watch().await?;

            // Send all the existing keys
340
            for (key, bytes) in bucket.entries().await? {
341
342
343
344
345
346
347
348
349
                if let Err(err) = tx
                    .send_timeout(
                        WatchEvent::Put(KeyValue::new(key, bytes)),
                        WATCH_SEND_TIMEOUT,
                    )
                    .await
                {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
                }
350
351
352
            }

            // Now block waiting for new entries
353
            loop {
354
                let event = tokio::select! {
355
356
                    _ = cancel_token.cancelled() => break,
                    result = stream.next() => match result {
357
                        Some(event) => event,
358
359
360
                        None => break,
                    }
                };
361
362
363
                if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
                }
364
365
            }

366
            Ok::<(), StoreError>(())
367
368
369
370
371
372
373
374
        });
        (watch_task, rx)
    }

    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
        &self,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
375
        key: &Key,
376
        obj: &mut T,
377
    ) -> anyhow::Result<StoreOutcome> {
378
        let obj_json = serde_json::to_vec(obj)?;
379
380
        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;

381
        let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;
382
383

        match outcome {
384
            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
385
386
387
388
389
                obj.set_revision(revision);
            }
        }
        Ok(outcome)
    }
390
391
392
393
394
395

    /// Cleanup any temporary state.
    /// TODO: Should this be async? Take &mut self?
    pub fn shutdown(&self) {
        self.0.shutdown()
    }
396
397
398
399
}

/// An online storage for key-value config values.
#[async_trait]
400
pub trait Bucket: Send + Sync {
401
402
    /// A bucket is a collection of key/value pairs.
    /// Insert a value into a bucket, if it doesn't exist already
403
    /// The Key should be the name of the item, not including the bucket name.
404
405
    async fn insert(
        &self,
406
        key: &Key,
407
        value: bytes::Bytes,
408
        revision: u64,
409
    ) -> Result<StoreOutcome, StoreError>;
410
411

    /// Fetch an item from the key-value storage
412
    /// The Key should be the name of the item, not including the bucket name.
413
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
414
415

    /// Delete an item from the bucket
416
    /// The Key should be the name of the item, not including the bucket name.
417
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;
418
419
420
421
422
423

    /// A stream of items inserted into the bucket.
    /// Every time the stream is polled it will either return a newly created entry, or block until
    /// such time.
    async fn watch(
        &self,
424
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
425

426
427
428
429
    /// The entries in this bucket.
    /// The Key includes the full path including the bucket name.
    /// That means you cannot directory get a Key from `entries` and pass it to `get` or `delete`.
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError>;
430
431
432
}

/// The result of inserting into a bucket.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}

impl fmt::Display for StoreOutcome {
    /// Renders as `Created at <revision>` or `Exists at <revision>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Exists(revision) => write!(f, "Exists at {revision}"),
            Self::Created(revision) => write!(f, "Created at {revision}"),
        }
    }
}

#[derive(thiserror::Error, Debug)]
450
pub enum StoreError {
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    #[error("Could not find key '{0}'")]
    MissingKey(String),

    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    #[error("Internal NATS error: {0}")]
    NATSError(String),

    #[error("Internal etcd error: {0}")]
    EtcdError(String),

466
467
468
    #[error("Internal filesystem error: {0}")]
    FilesystemError(String),

469
    #[error("Key Value Error: {0} for bucket '{1}'")]
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
    KeyValueError(String, String),

    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    #[error("Race condition, retry the call")]
    Retry,
}

/// An object that carries a store revision that can be read and replaced.
///
/// NATS uses this to ensure atomic updates. `Manager::publish` passes the current revision
/// to the store and writes back the revision the store reports.
pub trait Versioned {
    /// The object's current revision.
    fn revision(&self) -> u64;

    /// Replaces the object's revision with `r`.
    fn set_revision(&mut self, r: u64);
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` into a broadcast channel of
        /// capacity `max_size`.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // An error here only means there are currently no subscribers; the item
                    // is dropped.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// A new receiver for items forwarded after this call.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                Key::new(format!("test{i}")),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values. The first `?` surfaces a panic
        // inside the task (JoinError). The second surfaces a StoreError the task returned,
        // which used to be silently discarded.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new(Key::new("test1".to_string()), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        // NOTE(review): the join results are discarded, so a panic inside handle1/handle2
        // (e.g. a failed `assert_eq!`) does NOT fail this test. Before propagating them,
        // confirm which event arrives first. In `test_memory_storage` the watch also yields
        // entries inserted before it started, which here would be test1="value1", not "GK".
        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}