".buildkite/vscode:/vscode.git/clone" did not exist on "d565e0976fb5ffd353727066ac8aa98272e318af"
kv.rs 18.7 KB
Newer Older
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else.

7
use std::borrow::Cow;
8
use std::pin::Pin;
9
use std::str::FromStr;
10
11
use std::sync::Arc;
use std::time::Duration;
12
13
use std::{collections::HashMap, path::PathBuf};
use std::{env, fmt};
14

15
use crate::CancellationToken;
16
use crate::transports::etcd as etcd_transport;
17
18
use async_trait::async_trait;
use futures::StreamExt;
19
use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
20
21
22
use serde::{Deserialize, Serialize};

mod mem;
23
pub use mem::MemoryStore;
24
mod nats;
25
pub use nats::NATSStore;
26
mod etcd;
27
pub use etcd::EtcdStore;
28
29
mod file;
pub use file::FileStore;
30

31
/// Upper bound on how long a single watch event may wait for room in the
/// output channel before the send is abandoned.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_secs(1);
32

33
34
35
/// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string.
///
/// Equality and hashing are derived from the wrapped string, so a `Key` can be used
/// directly as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);

impl Key {
39
40
    pub fn new(s: String) -> Key {
        Key(s)
41
42
    }

43
44
45
46
47
48
49
50
51
52
    /// Takes a URL-safe percent-encoded string and creates a Key from it by decoding first.
    /// dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f becomes dynamo/backend/generate/17216e63492ef21f
    pub fn from_url_safe(s: &str) -> Key {
        Key(percent_decode_str(s).decode_utf8_lossy().to_string())
    }

    /// A URL-safe percent-encoded representation of this key.
    /// e.g.  dynamo/backend/generate/17216e63492ef21f becomes dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f
    pub fn url_safe(&self) -> Cow<'_, str> {
        percent_encode(self.0.as_bytes(), NON_ALPHANUMERIC).into()
53
54
55
56
57
    }
}

impl From<&str> for Key {
    fn from(s: &str) -> Key {
58
        Key::new(s.to_string())
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    }
}

/// Renders the key as its raw (not percent-encoded) string.
impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

/// Borrows the key's underlying string without copying.
impl AsRef<str> for Key {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<&Key> for String {
    fn from(k: &Key) -> String {
        k.0.clone()
    }
}

80
81
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
82
    key: Key,
83
84
85
86
    value: bytes::Bytes,
}

impl KeyValue {
87
    pub fn new(key: Key, value: bytes::Bytes) -> Self {
88
89
        KeyValue { key, value }
    }
90
91

    pub fn key(&self) -> String {
92
        self.key.clone().to_string()
93
94
95
    }

    pub fn key_str(&self) -> &str {
96
        self.key.as_ref()
97
98
99
100
101
102
103
104
105
    }

    pub fn value(&self) -> &[u8] {
        &self.value
    }

    pub fn value_str(&self) -> anyhow::Result<&str> {
        std::str::from_utf8(self.value()).map_err(From::from)
    }
106
107
108
109
110
}

#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    Put(KeyValue),
111
    Delete(Key),
112
113
}

114
#[async_trait]
115
116
pub trait Store: Send + Sync {
    type Bucket: Bucket + Send + Sync + 'static;
117

118
119
120
121
122
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
123
124
125
126
127
    ) -> Result<Self::Bucket, StoreError>;

    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    fn connection_id(&self) -> u64;
128
129
130
131
132

    fn shutdown(&self);
}

#[derive(Clone, Debug, Default)]
133
pub enum Selector {
134
135
136
137
138
139
140
141
    // Box it because it is significantly bigger than the other variants
    Etcd(Box<etcd_transport::ClientOptions>),
    File(PathBuf),
    #[default]
    Memory,
    // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
}

142
impl fmt::Display for Selector {
143
144
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
145
            Selector::Etcd(opts) => {
146
147
148
                let urls = opts.etcd_url.join(",");
                write!(f, "Etcd({urls})")
            }
149
150
            Selector::File(path) => write!(f, "File({})", path.display()),
            Selector::Memory => write!(f, "Memory"),
151
152
153
154
        }
    }
}

155
impl FromStr for Selector {
156
157
    type Err = anyhow::Error;

158
    fn from_str(s: &str) -> anyhow::Result<Selector> {
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        match s {
            "etcd" => Ok(Self::Etcd(Box::default())),
            "file" => {
                let root = env::var("DYN_FILE_KV")
                    .map(PathBuf::from)
                    .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
                Ok(Self::File(root))
            }
            "mem" => Ok(Self::Memory),
            x => anyhow::bail!("Unknown key-value store type '{x}'"),
        }
    }
}

173
impl TryFrom<String> for Selector {
174
175
    type Error = anyhow::Error;

176
    fn try_from(s: String) -> anyhow::Result<Selector> {
177
178
        s.parse()
    }
179
180
181
}

#[allow(clippy::large_enum_variant)]
182
enum KeyValueStoreEnum {
183
184
185
    Memory(MemoryStore),
    Nats(NATSStore),
    Etcd(EtcdStore),
186
    File(FileStore),
187
188
189
190
191
192
193
194
}

/// Forwards each `Store` operation to whichever backend is wrapped, type-erasing the
/// backend's bucket type into `Box<dyn Bucket>`.
impl KeyValueStoreEnum {
    /// Fetch the named bucket from the wrapped backend, creating it if needed.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
    ) -> Result<Box<dyn Bucket>, StoreError> {
        use KeyValueStoreEnum::*;
        Ok(match self {
            Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
            Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
            Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
            File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
        })
    }

    /// Fetch the named bucket from the wrapped backend. `Ok(None)` if the backend has none.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Box<dyn Bucket>>, StoreError> {
        use KeyValueStoreEnum::*;
        // Each arm yields a different concrete bucket type, so every arm erases it explicitly.
        let maybe_bucket: Option<Box<dyn Bucket>> = match self {
            Memory(x) => x
                .get_bucket(bucket_name)
                .await?
                .map(|b| Box::new(b) as Box<dyn Bucket>),
            Nats(x) => x
                .get_bucket(bucket_name)
                .await?
                .map(|b| Box::new(b) as Box<dyn Bucket>),
            Etcd(x) => x
                .get_bucket(bucket_name)
                .await?
                .map(|b| Box::new(b) as Box<dyn Bucket>),
            File(x) => x
                .get_bucket(bucket_name)
                .await?
                .map(|b| Box::new(b) as Box<dyn Bucket>),
        };
        Ok(maybe_bucket)
    }

    /// The wrapped backend's connection id.
    fn connection_id(&self) -> u64 {
        use KeyValueStoreEnum::*;
        match self {
            Memory(x) => x.connection_id(),
            Nats(x) => x.connection_id(),
            Etcd(x) => x.connection_id(),
            File(x) => x.connection_id(),
        }
    }

    /// Shut down the wrapped backend.
    fn shutdown(&self) {
        use KeyValueStoreEnum::*;
        match self {
            Memory(x) => x.shutdown(),
            Nats(x) => x.shutdown(),
            Etcd(x) => x.shutdown(),
            File(x) => x.shutdown(),
        }
    }
}

249
#[derive(Clone)]
250
pub struct Manager(Arc<KeyValueStoreEnum>);
251

252
impl Default for Manager {
253
    fn default() -> Self {
254
        Manager::memory()
255
256
    }
}
257

258
impl Manager {
259
260
261
262
263
264
265
266
267
    /// In-memory KeyValueStoreManager for testing
    pub fn memory() -> Self {
        Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
    }

    pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
        Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
    }

268
269
    pub fn file<P: Into<PathBuf>>(cancel_token: CancellationToken, root: P) -> Self {
        Self::new(KeyValueStoreEnum::File(FileStore::new(cancel_token, root)))
270
271
    }

272
273
    fn new(s: KeyValueStoreEnum) -> Manager {
        Manager(Arc::new(s))
274
275
276
277
278
279
280
    }

    pub async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
281
    ) -> Result<Box<dyn Bucket>, StoreError> {
282
283
284
285
286
287
        self.0.get_or_create_bucket(bucket_name, ttl).await
    }

    pub async fn get_bucket(
        &self,
        bucket_name: &str,
288
    ) -> Result<Option<Box<dyn Bucket>>, StoreError> {
289
290
291
292
293
        self.0.get_bucket(bucket_name).await
    }

    pub fn connection_id(&self) -> u64 {
        self.0.connection_id()
294
295
296
297
298
    }

    pub async fn load<T: for<'a> Deserialize<'a>>(
        &self,
        bucket: &str,
299
        key: &Key,
300
    ) -> Result<Option<T>, StoreError> {
301
302
303
304
        let Some(bucket) = self.0.get_bucket(bucket).await? else {
            // No bucket means no cards
            return Ok(None);
        };
305
306
        Ok(match bucket.get(key).await? {
            Some(card_bytes) => {
307
                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
308
                Some(card)
309
            }
310
311
            None => None,
        })
312
313
314
315
316
    }

    /// Returns a receiver that will receive all the existing keys, and
    /// then block and receive new keys as they are created.
    /// Starts a task that runs forever, watches the store.
317
    pub fn watch(
318
319
320
        self: Arc<Self>,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
321
        cancel_token: CancellationToken,
322
    ) -> (
323
        tokio::task::JoinHandle<Result<(), StoreError>>,
324
        tokio::sync::mpsc::Receiver<WatchEvent>,
325
326
    ) {
        let bucket_name = bucket_name.to_string();
327
        let (tx, rx) = tokio::sync::mpsc::channel(1024);
328
329
330
331
332
333
334
335
336
        let watch_task = tokio::spawn(async move {
            // Start listening for changes but don't poll this yet
            let bucket = self
                .0
                .get_or_create_bucket(&bucket_name, bucket_ttl)
                .await?;
            let mut stream = bucket.watch().await?;

            // Send all the existing keys
337
            for (key, bytes) in bucket.entries().await? {
338
339
340
341
342
343
344
345
346
                if let Err(err) = tx
                    .send_timeout(
                        WatchEvent::Put(KeyValue::new(key, bytes)),
                        WATCH_SEND_TIMEOUT,
                    )
                    .await
                {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
                }
347
348
349
            }

            // Now block waiting for new entries
350
            loop {
351
                let event = tokio::select! {
352
353
                    _ = cancel_token.cancelled() => break,
                    result = stream.next() => match result {
354
                        Some(event) => event,
355
356
357
                        None => break,
                    }
                };
358
359
360
                if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
                }
361
362
            }

363
            Ok::<(), StoreError>(())
364
365
366
367
368
369
370
371
        });
        (watch_task, rx)
    }

    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
        &self,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
372
        key: &Key,
373
        obj: &mut T,
374
    ) -> anyhow::Result<StoreOutcome> {
375
        let obj_json = serde_json::to_vec(obj)?;
376
377
        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;

378
        let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;
379
380

        match outcome {
381
            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
382
383
384
385
386
                obj.set_revision(revision);
            }
        }
        Ok(outcome)
    }
387
388
389
390
391
392

    /// Cleanup any temporary state.
    /// TODO: Should this be async? Take &mut self?
    pub fn shutdown(&self) {
        self.0.shutdown()
    }
393
394
395
396
}

/// An online storage for key-value config values.
#[async_trait]
397
pub trait Bucket: Send + Sync {
398
399
    /// A bucket is a collection of key/value pairs.
    /// Insert a value into a bucket, if it doesn't exist already
400
    /// The Key should be the name of the item, not including the bucket name.
401
402
    async fn insert(
        &self,
403
        key: &Key,
404
        value: bytes::Bytes,
405
        revision: u64,
406
    ) -> Result<StoreOutcome, StoreError>;
407
408

    /// Fetch an item from the key-value storage
409
    /// The Key should be the name of the item, not including the bucket name.
410
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
411
412

    /// Delete an item from the bucket
413
    /// The Key should be the name of the item, not including the bucket name.
414
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;
415
416
417
418
419
420

    /// A stream of items inserted into the bucket.
    /// Every time the stream is polled it will either return a newly created entry, or block until
    /// such time.
    async fn watch(
        &self,
421
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
422

423
424
425
426
    /// The entries in this bucket.
    /// The Key includes the full path including the bucket name.
    /// That means you cannot directory get a Key from `entries` and pass it to `get` or `delete`.
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError>;
427
428
429
}

/// Result of a successful insert into a bucket.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}
437
impl fmt::Display for StoreOutcome {
438
439
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
440
441
            StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
            StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
442
443
444
445
446
        }
    }
}

#[derive(thiserror::Error, Debug)]
447
pub enum StoreError {
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    #[error("Could not find key '{0}'")]
    MissingKey(String),

    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    #[error("Internal NATS error: {0}")]
    NATSError(String),

    #[error("Internal etcd error: {0}")]
    EtcdError(String),

463
464
465
    #[error("Internal filesystem error: {0}")]
    FilesystemError(String),

466
    #[error("Key Value Error: {0} for bucket '{1}'")]
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
    KeyValueError(String, String),

    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    #[error("Race condition, retry the call")]
    Retry,
}

/// A trait allowing to get/set a revision on an object.
/// NATS uses this to ensure atomic updates.
///
/// `Manager::publish` reads the revision before inserting the object and writes back
/// the revision reported by the store afterwards.
pub trait Versioned {
    /// The revision currently recorded on the object.
    fn revision(&self) -> u64;
    /// Record `r` as the object's revision.
    fn set_revision(&mut self, r: u64);
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` to the broadcast channel.
        /// `max_size` is the broadcast channel capacity.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // Sending fails when nobody is subscribed; that is fine here.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// A new receiver that sees events sent after this call.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                Key::new(format!("test{i}")),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values.
        // First `?` surfaces a panic in the task, second surfaces its StoreError.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leaked so the watch stream, which borrows the bucket, can satisfy `'static`.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new(Key::new("test1".to_string()), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        // Propagate join errors so a failed assertion inside either task fails the test.
        let (first, second) = futures::join!(handle1, handle2);
        first?;
        second?;
        Ok(())
    }
}