Unverified Commit 27904535 authored by Graham King, committed by GitHub
Browse files

fix(storage): Correctly encoding FileStore keys (#4539)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent f2769d88
......@@ -2860,6 +2860,7 @@ dependencies = [
"opentelemetry-otlp",
"opentelemetry_sdk",
"parking_lot",
"percent-encoding",
"prometheus",
"rand 0.9.2",
"rayon",
......
......@@ -1716,6 +1716,7 @@ dependencies = [
"opentelemetry-otlp",
"opentelemetry_sdk",
"parking_lot",
"percent-encoding",
"prometheus",
"rand 0.9.2",
"rayon",
......
......@@ -315,7 +315,7 @@ impl ModelManager {
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None)
.await?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = Key::from_raw(format!("{}/{router_uuid}", endpoint.path()));
let router_key = Key::new(format!("{}/{router_uuid}", endpoint.path()));
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?;
router_bucket
.insert(&router_key, json_router_config.into(), 0)
......
......@@ -485,11 +485,11 @@ async fn cleanup_orphaned_consumers(
.iter()
.filter_map(|(key, _)| {
// Check if key contains this component's path
if !key.contains(&component_path) {
if !key.as_ref().contains(&component_path) {
return None;
}
// Extract the last part (should be the UUID)
key.split('/').next_back().map(str::to_string)
key.as_ref().split('/').next_back().map(str::to_string)
})
.collect();
......
......@@ -79,6 +79,7 @@ nid = { version = "3.0.0", features = ["serde"] }
nix = { version = "0.29", features = ["signal"] }
nuid = { version = "0.5" }
once_cell = { version = "1" }
percent-encoding = { version = "2.3.2" } # also used by tonic, reqwest, axum, etc
rayon = { version = "1.10" }
regex = { version = "1" }
socket2 = { version = "0.5.8" }
......
......@@ -913,6 +913,7 @@ dependencies = [
"opentelemetry-otlp",
"opentelemetry_sdk",
"parking_lot",
"percent-encoding",
"prometheus",
"rand 0.9.2",
"rayon",
......@@ -1504,7 +1505,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.5.9",
"socket2 0.6.1",
"tokio",
"tower-service",
"tracing",
......@@ -2509,9 +2510,9 @@ dependencies = [
[[package]]
name = "percent-encoding"
version = "2.3.1"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pest"
......
......@@ -184,7 +184,7 @@ impl Discovery for KVStoreDiscovery {
key_path
);
let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
let key = crate::storage::key_value_store::Key::from_raw(key_path.clone());
let key = crate::storage::key_value_store::Key::new(key_path.clone());
tracing::debug!(
"KVStoreDiscovery::register: Inserting into bucket={}, key={}",
......@@ -251,7 +251,7 @@ impl Discovery for KVStoreDiscovery {
return Ok(());
};
let key = crate::storage::key_value_store::Key::from_raw(key_path.clone());
let key = crate::storage::key_value_store::Key::new(key_path.clone());
// Delete the entry from the bucket
bucket.delete(&key).await?;
......@@ -277,12 +277,12 @@ impl Discovery for KVStoreDiscovery {
// Filter by prefix and deserialize
let mut instances = Vec::new();
for (key_str, value) in entries {
if Self::matches_prefix(&key_str, &prefix, bucket_name) {
for (key, value) in entries {
if Self::matches_prefix(key.as_ref(), &prefix, bucket_name) {
match Self::parse_instance(&value) {
Ok(instance) => instances.push(instance),
Err(e) => {
tracing::warn!(key = %key_str, error = %e, "Failed to parse discovery instance");
tracing::warn!(%key, error = %e, "Failed to parse discovery instance");
}
}
}
......
......@@ -4,6 +4,7 @@
//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else.
use std::borrow::Cow;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
......@@ -12,10 +13,10 @@ use std::{collections::HashMap, path::PathBuf};
use std::{env, fmt};
use crate::CancellationToken;
use crate::slug::Slug;
use crate::transports::etcd as etcd_transport;
use async_trait::async_trait;
use futures::StreamExt;
use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
use serde::{Deserialize, Serialize};
mod mem;
......@@ -29,27 +30,32 @@ pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// A key that is safe to use directly in the KV store.
///
/// TODO: Need to re-think this. etcd uses slash separators, so we often use from_raw
/// to avoid the slug. But other impl's, particularly file, need a real slug.
#[derive(Debug, Clone, PartialEq)]
/// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
impl Key {
pub fn new(s: &str) -> Key {
Key(Slug::slugify(s).to_string())
pub fn new(s: String) -> Key {
Key(s)
}
/// Create a Key without changing the string, it is assumed already KV store safe.
pub fn from_raw(s: String) -> Key {
Key(s)
/// Takes a URL-safe percent-encoded string and creates a Key from it by decoding first.
/// dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f becomes dynamo/backend/generate/17216e63492ef21f
pub fn from_url_safe(s: &str) -> Key {
Key(percent_decode_str(s).decode_utf8_lossy().to_string())
}
/// Returns this key in URL-safe, percent-encoded form.
/// Encoding uses the `NON_ALPHANUMERIC` set, so separators such as `/` are escaped:
/// dynamo/backend/generate/17216e63492ef21f becomes dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f
pub fn url_safe(&self) -> Cow<'_, str> {
    let raw_bytes = self.0.as_bytes();
    Cow::from(percent_encode(raw_bytes, NON_ALPHANUMERIC))
}
}
impl From<&str> for Key {
fn from(s: &str) -> Key {
Key::new(s)
Key::new(s.to_string())
}
}
......@@ -73,21 +79,21 @@ impl From<&Key> for String {
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
key: String,
key: Key,
value: bytes::Bytes,
}
impl KeyValue {
pub fn new(key: String, value: bytes::Bytes) -> Self {
pub fn new(key: Key, value: bytes::Bytes) -> Self {
KeyValue { key, value }
}
pub fn key(&self) -> String {
self.key.clone()
self.key.clone().to_string()
}
pub fn key_str(&self) -> &str {
&self.key
self.key.as_ref()
}
pub fn value(&self) -> &[u8] {
......@@ -394,6 +400,7 @@ impl KeyValueStoreManager {
pub trait KeyValueBucket: Send + Sync {
/// A bucket is a collection of key/value pairs.
/// Insert a value into a bucket, if it doesn't exist already
/// The Key should be the name of the item, not including the bucket name.
async fn insert(
&self,
key: &Key,
......@@ -402,9 +409,11 @@ pub trait KeyValueBucket: Send + Sync {
) -> Result<StoreOutcome, StoreError>;
/// Fetch an item from the key-value storage
/// The Key should be the name of the item, not including the bucket name.
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
/// Delete an item from the bucket
/// The Key should be the name of the item, not including the bucket name.
async fn delete(&self, key: &Key) -> Result<(), StoreError>;
/// A stream of items inserted into the bucket.
......@@ -414,7 +423,10 @@ pub trait KeyValueBucket: Send + Sync {
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
/// The entries in this bucket.
/// The Key includes the full path including the bucket name.
/// That means you cannot directly get a Key from `entries` and pass it to `get` or `delete`.
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError>;
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
......@@ -527,7 +539,7 @@ mod tests {
let mut expected = Vec::with_capacity(3);
for i in 1..=3 {
let item = WatchEvent::Put(KeyValue::new(
format!("test{i}"),
Key::new(format!("test{i}")),
format!("value{i}").into(),
));
expected.push(item);
......@@ -596,7 +608,7 @@ mod tests {
let mut rx1 = tap.subscribe();
let mut rx2 = tap.subscribe();
let item = WatchEvent::Put(KeyValue::new("test1".to_string(), "GK".into()));
let item = WatchEvent::Put(KeyValue::new(Key::new("test1".to_string()), "GK".into()));
let item_clone = item.clone();
let handle1 = tokio::spawn(async move {
let b = rx1.recv().await.unwrap();
......
......@@ -126,7 +126,7 @@ impl KeyValueBucket for EtcdBucket {
etcd::WatchEvent::Put(kv) => {
let (k, v) = kv.into_key_value();
let key = match String::from_utf8(k) {
Ok(k) => k,
Ok(k) => Key::new(k),
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
......@@ -138,13 +138,13 @@ impl KeyValueBucket for EtcdBucket {
etcd::WatchEvent::Delete(kv) => {
let (k, _) = kv.into_key_value();
let key = match String::from_utf8(k) {
Ok(k) => k,
Ok(k) => Key::new(k),
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
}
};
yield WatchEvent::Delete(Key::from_raw(key));
yield WatchEvent::Delete(key);
}
}
}
......@@ -152,7 +152,7 @@ impl KeyValueBucket for EtcdBucket {
Ok(Box::pin(output))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}");
......@@ -161,11 +161,11 @@ impl KeyValueBucket for EtcdBucket {
.kv_get_prefix(k)
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let out: HashMap<String, bytes::Bytes> = resp
let out: HashMap<Key, bytes::Bytes> = resp
.into_iter()
.map(|kv| {
let (k, v) = kv.into_key_value();
(String::from_utf8_lossy(&k).to_string(), v.into())
(Key::new(String::from_utf8_lossy(&k).to_string()), v.into())
})
.collect();
......@@ -287,7 +287,7 @@ mod concurrent_create_tests {
let barrier = Arc::new(Barrier::new(num_workers));
// Shared test data
let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
let test_key: Key = Key::new(format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
let test_value = "test_value";
// Spawn multiple tasks that will all try to create the same key simultaneously
......
......@@ -286,7 +286,7 @@ impl KeyValueBucket for Directory {
value: bytes::Bytes,
_revision: u64, // Not used. Maybe put in file name?
) -> Result<StoreOutcome, StoreError> {
let safe_key = Key::new(key.as_ref()); // because of from_raw
let safe_key = key.url_safe();
let full_path = self.p.join(safe_key.as_ref());
self.owned_files.lock().insert(full_path.clone());
let str_path = full_path.display().to_string();
......@@ -298,7 +298,7 @@ impl KeyValueBucket for Directory {
/// Read a file from the directory
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let safe_key = Key::new(key.as_ref()); // because of from_raw
let safe_key = key.url_safe();
let full_path = self.p.join(safe_key.as_ref());
if !full_path.exists() {
return Ok(None);
......@@ -313,7 +313,7 @@ impl KeyValueBucket for Directory {
/// Delete a file from the directory
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let safe_key = Key::new(key.as_ref()); // because of from_raw
let safe_key = key.url_safe();
let full_path = self.p.join(safe_key.as_ref());
let str_path = full_path.display().to_string();
if !full_path.exists() {
......@@ -374,7 +374,7 @@ impl KeyValueBucket for Directory {
let canonical_item_path = item_path.canonicalize().unwrap_or_else(|_| item_path.clone());
let key = match canonical_item_path.strip_prefix(&root) {
Ok(stripped) => stripped.display().to_string().replace("_", "/"),
Ok(stripped) => Key::from_url_safe(&stripped.display().to_string()),
Err(err) => {
// Possibly this should be a panic.
// A key cannot be outside the file store root.
......@@ -400,7 +400,7 @@ impl KeyValueBucket for Directory {
yield WatchEvent::Put(item);
}
EventKind::Remove(event::RemoveKind::File) => {
yield WatchEvent::Delete(Key::from_raw(key));
yield WatchEvent::Delete(key);
}
_ => {
// These happen every time the keep-alive updates last modified time
......@@ -412,7 +412,7 @@ impl KeyValueBucket for Directory {
}))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let contents = fs::read_dir(&self.p)
.with_context(|| self.p.display().to_string())
.map_err(a_to_fs_err)?;
......@@ -437,7 +437,7 @@ impl KeyValueBucket for Directory {
};
let key = match canonical_entry_path.strip_prefix(&self.root) {
Ok(p) => p.to_string_lossy().to_string().replace("_", "/"),
Ok(p) => Key::from_url_safe(&p.to_string_lossy()),
Err(err) => {
tracing::error!(
error = %err,
......@@ -482,17 +482,17 @@ mod tests {
let m = FileStore::new(t.path());
let bucket = m.get_or_create_bucket("v1/tests", None).await.unwrap();
let _ = bucket
.insert(&Key::new("key1/multi/part"), "value1".into(), 0)
.insert(&Key::new("key1/multi/part".to_string()), "value1".into(), 0)
.await
.unwrap();
let _ = bucket
.insert(&Key::new("key2"), "value2".into(), 0)
.insert(&Key::new("key2".to_string()), "value2".into(), 0)
.await
.unwrap();
let entries = bucket.entries().await.unwrap();
let keys: HashSet<String> = entries.into_keys().collect();
let keys: HashSet<Key> = entries.into_keys().collect();
assert!(keys.contains("v1/tests/key1/multi/part"));
assert!(keys.contains("v1/tests/key2"));
assert!(keys.contains(&Key::new("v1/tests/key1/multi/part".to_string())));
assert!(keys.contains(&Key::new("v1/tests/key2".to_string())));
}
}
......@@ -182,7 +182,7 @@ impl KeyValueBucket for MemoryBucketRef {
};
for (key, (_rev, v)) in &bucket.data {
seen_keys.insert(key.clone());
let item = KeyValue::new(key.clone(), v.clone());
let item = KeyValue::new(Key::new(key.clone()), v.clone());
existing_items.push(WatchEvent::Put(item));
}
drop(data_lock);
......@@ -203,25 +203,29 @@ impl KeyValueBucket for MemoryBucketRef {
if seen_keys.contains(&key) {
continue;
}
let item = KeyValue::new(key, value);
let item = KeyValue::new(Key::new(key), value);
yield WatchEvent::Put(item);
},
Some(MemoryEvent::Delete { key }) => {
yield WatchEvent::Delete(Key::from_raw(key));
yield WatchEvent::Delete(Key::new(key));
}
}
}
}))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let locked_data = self.inner.data.lock();
match locked_data.get(&self.name) {
Some(bucket) => Ok(bucket
.data
.iter()
.map(|(k, (_rev, v))| ([self.name.clone(), k.to_string()].join("/"), v.clone()))
.collect()),
Some(bucket) => {
let mut out = HashMap::new();
for (k, (_rev, v)) in bucket.data.iter() {
let key = Key::new([self.name.clone(), k.to_string()].join("/"));
let value = v.clone();
out.insert(key, value);
}
Ok(out)
}
None => Err(StoreError::MissingBucket(self.name.clone())),
}
}
......@@ -240,16 +244,16 @@ mod tests {
let m = MemoryStore::new();
let bucket = m.get_or_create_bucket("bucket1", None).await.unwrap();
let _ = bucket
.insert(&Key::new("key1"), "value1".into(), 0)
.insert(&Key::new("key1".to_string()), "value1".into(), 0)
.await
.unwrap();
let _ = bucket
.insert(&Key::new("key2"), "value2".into(), 0)
.insert(&Key::new("key2".to_string()), "value2".into(), 0)
.await
.unwrap();
let entries = bucket.entries().await.unwrap();
let keys: HashSet<String> = entries.into_keys().collect();
assert!(keys.contains("bucket1/key1"));
assert!(keys.contains("bucket1/key2"));
let keys: HashSet<Key> = entries.into_keys().collect();
assert!(keys.contains(&Key::new("bucket1/key1".to_string())));
assert!(keys.contains(&Key::new("bucket1/key2".to_string())));
}
}
......@@ -165,14 +165,15 @@ impl KeyValueBucket for NATSBucket {
>| async move {
match maybe_entry {
Ok(entry) => {
let key = Key::new(entry.key);
Some(match entry.operation {
Operation::Put => {
let item = KeyValue::new(entry.key, entry.value);
let item = KeyValue::new(key, entry.value);
WatchEvent::Put(item)
}
Operation::Delete => WatchEvent::Delete(Key::from_raw(entry.key)),
Operation::Delete => WatchEvent::Delete(key),
// TODO: What is Purge? Not urgent, NATS impl not used
Operation::Purge => WatchEvent::Delete(Key::from_raw(entry.key)),
Operation::Purge => WatchEvent::Delete(key),
})
}
Err(e) => {
......@@ -185,7 +186,7 @@ impl KeyValueBucket for NATSBucket {
))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let mut key_stream = self
.nats_store
.keys()
......@@ -194,7 +195,7 @@ impl KeyValueBucket for NATSBucket {
let mut out = HashMap::new();
while let Some(Ok(key)) = key_stream.next().await {
if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
out.insert(key, entry.value);
out.insert(Key::new(key), entry.value);
}
}
Ok(out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment