Unverified Commit 794c0a44 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(keyvalue): Filesystem backed KeyValueStore (#4138)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 3fd0ab3d
......@@ -25,7 +25,7 @@ def dynamo_worker(static=False):
@wraps(func)
async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, static)
runtime = DistributedRuntime(loop, "etcd", static)
await func(runtime, *args, **kwargs)
......
......@@ -256,7 +256,7 @@ async def test_server_context_cancel(server, client):
except ValueError as e:
# Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError?
assert str(e) == "Stream ended before generation completed"
assert str(e).startswith("Stream ended before generation completed")
# Verify server context cancellation status
assert handler.context_is_stopped
......
......@@ -82,20 +82,17 @@ def run_client(example_dir, use_middle=False):
)
# Wait for client to complete
stdout, _ = client_proc.communicate(timeout=1)
if client_proc.returncode != 0:
pytest.fail(
f"Client failed with return code {client_proc.returncode}. Output: {stdout}"
)
stdout, _ = client_proc.communicate(timeout=2)
print(f"Client stdout: {stdout}")
return stdout
def stop_process(process):
def stop_process(name, process):
"""Stop a running process and capture its output"""
process.terminate()
stdout, _ = process.communicate(timeout=1)
print(f"{name}: {stdout}")
return stdout
......@@ -109,7 +106,7 @@ async def test_direct_connection_cancellation(example_dir, server_process):
await asyncio.sleep(1)
# Capture server output
server_output = stop_process(server_process)
server_output = stop_process("server_process", server_process)
# Assert expected messages
assert (
......@@ -132,8 +129,8 @@ async def test_middle_server_cancellation(
await asyncio.sleep(1)
# Capture output from all processes
server_output = stop_process(server_process)
middle_output = stop_process(middle_server_process)
server_output = stop_process("server_process", server_process)
middle_output = stop_process("middle_server_process", middle_server_process)
# Assert expected messages
assert (
......
......@@ -153,7 +153,7 @@ def start_nats_and_etcd_default_ports():
print(f"Using ETCD on default client port {etcd_client_port}")
# Start services with default ports
nats_server = subprocess.Popen(["nats-server", "-js"])
nats_server = subprocess.Popen(["nats-server", "-js", "--trace"])
etcd = subprocess.Popen(["etcd"])
return nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir
......@@ -181,6 +181,8 @@ def start_nats_and_etcd_random_ports():
etcd = subprocess.Popen(
[
"etcd",
"--logger",
"zap",
"--data-dir",
str(etcd_data_dir),
"--listen-client-urls",
......@@ -221,7 +223,11 @@ def start_nats_and_etcd_random_ports():
msg = log.get("msg", "")
# Look for the client port
if "serving client traffic" in msg or "serving client" in msg:
if (
"serving client traffic" in msg
or "serving client" in msg
or "serving insecure client" in msg
):
address = log.get("address", "")
match = re.search(r":(\d+)$", address)
if match:
......@@ -430,6 +436,6 @@ This is required because DistributedRuntime is a process-level singleton.
)
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
runtime = DistributedRuntime(loop, "mem", True)
yield runtime
runtime.shutdown()
......@@ -34,7 +34,7 @@ async def distributed_runtime():
Each test gets its own runtime in a forked process to avoid singleton conflicts.
"""
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False)
runtime = DistributedRuntime(loop, "etcd", False)
yield runtime
runtime.shutdown()
......
......@@ -89,8 +89,8 @@ fn parse_sinks_from_env(
}
/// spawn one worker per sink; each subscribes to the bus (off hot path)
pub fn spawn_workers_from_env(drt: Option<&dynamo_runtime::DistributedRuntime>) {
let nats_client = drt.and_then(|d| d.nats_client());
pub fn spawn_workers_from_env(drt: &dynamo_runtime::DistributedRuntime) {
let nats_client = drt.nats_client();
let sinks = parse_sinks_from_env(nats_client);
for sink in sinks {
let name = sink.name();
......
......@@ -183,8 +183,8 @@ impl ModelWatcher {
}
}
}
WatchEvent::Delete(kv) => {
let deleted_key = kv.key_str();
WatchEvent::Delete(key) => {
let deleted_key = key.as_ref();
match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.await
......
......@@ -23,7 +23,6 @@ pub mod http;
pub mod text;
use dynamo_runtime::protocols::ENDPOINT_SCHEME;
use either::Either;
const BATCH_PREFIX: &str = "batch:";
......@@ -107,15 +106,10 @@ impl Default for Input {
/// For Input::Endpoint pass a DistributedRuntime. For everything else pass either a Runtime or a
/// DistributedRuntime.
pub async fn run_input(
rt: Either<dynamo_runtime::Runtime, dynamo_runtime::DistributedRuntime>,
drt: dynamo_runtime::DistributedRuntime,
in_opt: Input,
engine_config: super::EngineConfig,
) -> anyhow::Result<()> {
let runtime = match &rt {
Either::Left(rt) => rt.clone(),
Either::Right(drt) => drt.runtime().clone(),
};
// Initialize audit bus + sink workers (off hot path; fan-out supported)
if crate::audit::config::policy().enabled {
let cap: usize = std::env::var("DYN_AUDIT_CAPACITY")
......@@ -123,38 +117,30 @@ pub async fn run_input(
.and_then(|v| v.parse().ok())
.unwrap_or(1024);
crate::audit::bus::init(cap);
// Pass DistributedRuntime if available for shared NATS client
let drt_ref = match &rt {
Either::Right(drt) => Some(drt),
Either::Left(_) => None,
};
crate::audit::sink::spawn_workers_from_env(drt_ref);
tracing::info!("Audit initialized: bus cap={}", cap);
crate::audit::sink::spawn_workers_from_env(&drt);
tracing::info!(cap, "Audit initialized");
}
match in_opt {
Input::Http => {
http::run(runtime, engine_config).await?;
http::run(drt, engine_config).await?;
}
Input::Grpc => {
grpc::run(runtime, engine_config).await?;
grpc::run(drt, engine_config).await?;
}
Input::Text => {
text::run(runtime, None, engine_config).await?;
text::run(drt, None, engine_config).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
text::run(runtime, Some(prompt), engine_config).await?;
text::run(drt, Some(prompt), engine_config).await?;
}
Input::Batch(path) => {
batch::run(runtime, path, engine_config).await?;
batch::run(drt, path, engine_config).await?;
}
Input::Endpoint(path) => {
let Either::Right(distributed_runtime) = rt else {
anyhow::bail!("Input::Endpoint requires passing a DistributedRuntime");
};
endpoint::run(distributed_runtime, path, engine_config).await?;
endpoint::run(drt, path, engine_config).await?;
}
}
Ok(())
......
......@@ -8,7 +8,7 @@ use crate::types::openai::chat_completions::{
};
use anyhow::Context as _;
use dynamo_async_openai::types::FinishReason;
use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken};
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
......@@ -51,11 +51,11 @@ struct Entry {
}
pub async fn run(
runtime: Runtime,
distributed_runtime: DistributedRuntime,
input_jsonl: PathBuf,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let cancel_token = distributed_runtime.primary_token();
// Check if the path exists and is a directory
if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!(
......@@ -64,7 +64,7 @@ pub async fn run(
);
}
let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?;
let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(
......
......@@ -24,9 +24,8 @@ use crate::{
};
use dynamo_runtime::{
DistributedRuntime, Runtime,
DistributedRuntime,
component::Client,
distributed::DistributedConfig,
engine::{AsyncEngineStream, Data},
pipeline::{
Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend,
......@@ -55,23 +54,25 @@ impl PreparedEngine {
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
pub async fn prepare_engine(
runtime: Runtime,
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime,
distributed_runtime.clone(),
model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin,
None,
None,
));
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token());
let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await;
......@@ -98,9 +99,6 @@ pub async fn prepare_engine(
let card = local_model.card();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true);
let distributed_runtime = DistributedRuntime::new(runtime, dst_config).await?;
let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime
.namespace(&endpoint_id.namespace)?
......
......@@ -16,18 +16,20 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{DistributedRuntime, storage::key_value_store::KeyValueStoreManager};
/// Build and run an KServe gRPC service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
pub async fn run(
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let grpc_service_builder = kserve::KserveService::builder()
.port(engine_config.local_model().http_port()) // [WIP] generalize port..
.with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let store = Arc::new(distributed_runtime.store().clone());
let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config();
......@@ -39,7 +41,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(namespace.to_string())
};
run_watcher(
distributed_runtime,
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
store,
router_config.router_mode,
......@@ -55,8 +57,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
......@@ -157,8 +157,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
grpc_service
}
};
grpc_service.run(runtime.primary_token()).await?;
runtime.shutdown(); // Cancel primary token
grpc_service
.run(distributed_runtime.primary_token())
.await?;
distributed_runtime.shutdown(); // Cancel primary token
Ok(())
}
......
......@@ -17,12 +17,15 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
/// Build and run an HTTP service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
pub async fn run(
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let local_model = engine_config.local_model();
let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
(Some(tls_cert_path), Some(tls_key_path)) => {
......@@ -63,7 +66,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = match engine_config {
EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
// This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?;
......@@ -80,7 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(namespace.to_string())
};
run_watcher(
distributed_runtime,
distributed_runtime.clone(),
http_service.state().manager_clone(),
store,
router_config.router_mode,
......@@ -96,11 +98,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
......@@ -233,8 +231,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
http_service.custom_backend_metrics_polling_interval,
http_service.custom_backend_registry.as_ref(),
) {
// Create DistributedRuntime for polling, matching the engine's mode
let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
tracing::info!(
namespace_component_endpoint=%namespace_component_endpoint,
polling_interval_secs=polling_interval,
......@@ -246,7 +242,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
// shutdown phase.
Some(
crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
drt,
distributed_runtime.clone(),
namespace_component_endpoint.clone(),
polling_interval,
registry.clone(),
......@@ -256,14 +252,16 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None
};
http_service.run(runtime.primary_token()).await?;
http_service
.run(distributed_runtime.primary_token())
.await?;
// Abort the polling task if it was started
if let Some(task) = polling_task {
task.abort();
}
runtime.shutdown(); // Cancel primary token
distributed_runtime.shutdown(); // Cancel primary token
Ok(())
}
......
......@@ -5,7 +5,8 @@ use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::Context;
use futures::StreamExt;
use std::io::{ErrorKind, Write};
......@@ -17,15 +18,15 @@ use crate::entrypoint::input::common;
const MAX_TOKENS: u32 = 8192;
pub async fn run(
runtime: Runtime,
distributed_runtime: DistributedRuntime,
single_prompt: Option<String>,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let prepared_engine =
common::prepare_engine(distributed_runtime.clone(), engine_config).await?;
// TODO: Pass prepared_engine directly
main_loop(
cancel_token,
distributed_runtime,
&prepared_engine.service_name,
prepared_engine.engine,
single_prompt,
......@@ -36,13 +37,14 @@ pub async fn run(
}
async fn main_loop(
cancel_token: CancellationToken,
distributed_runtime: DistributedRuntime,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
mut initial_prompt: Option<String>,
_inspect_template: bool,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let cancel_token = distributed_runtime.primary_token();
if initial_prompt.is_none() {
tracing::info!("Ctrl-c to exit");
}
......@@ -179,7 +181,11 @@ async fn main_loop(
break;
}
}
cancel_token.cancel(); // stop everything else
println!();
// Stop the runtime and wait for it to stop
distributed_runtime.shutdown();
cancel_token.cancelled().await;
Ok(())
}
......@@ -167,7 +167,7 @@ mod tests {
bus::init(100);
let drt = create_test_drt().await;
sink::spawn_workers_from_env(Some(&drt));
sink::spawn_workers_from_env(&drt);
time::sleep(Duration::from_millis(100)).await;
// Emit audit record
......@@ -224,7 +224,7 @@ mod tests {
bus::init(100);
let drt = create_test_drt().await;
sink::spawn_workers_from_env(Some(&drt));
sink::spawn_workers_from_env(&drt);
time::sleep(Duration::from_millis(100)).await;
// Request with store=true (should be audited)
......
......@@ -63,6 +63,7 @@ bincode = { version = "1" }
console-subscriber = { version = "0.4", optional = true }
educe = { version = "0.6.0" }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
inotify = { version = "0.11" }
local-ip-address = { version = "0.6.3" }
log = { version = "0.4" }
nid = { version = "3.0.0", features = ["serde"] }
......
......@@ -679,6 +679,7 @@ dependencies = [
"figment",
"futures",
"humantime",
"inotify",
"local-ip-address",
"log",
"nid",
......@@ -1354,6 +1355,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.0",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "iovec"
version = "0.1.4"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn,
use crate::{
pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn,
},
storage::key_value_store::{KeyValueStoreManager, WatchEvent},
};
use arc_swap::ArcSwap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::unix::pipe::Receiver;
use crate::{
pipeline::async_trait,
transports::etcd::{Client as EtcdClient, WatchEvent},
};
use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient};
use super::*;
......@@ -70,12 +70,7 @@ impl Client {
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
// create live endpoint watcher
let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
};
let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
let client = Client {
endpoint,
......@@ -194,7 +189,6 @@ impl Client {
}
async fn get_or_create_dynamic_instance_source(
etcd_client: &EtcdClient,
endpoint: &Endpoint,
) -> Result<Arc<InstanceSource>> {
let drt = endpoint.drt();
......@@ -209,12 +203,10 @@ impl Client {
}
}
let prefix_watcher = etcd_client
.kv_get_and_watch_prefix(endpoint.etcd_root())
.await?;
let (prefix, mut kv_event_rx) = prefix_watcher.dissolve();
let prefix = endpoint.etcd_root();
let store = Arc::new(drt.store().clone());
let (_, mut kv_event_rx) =
store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token());
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let secondary = endpoint.component.drt.runtime.secondary().clone();
......@@ -223,7 +215,7 @@ impl Client {
// currently this is created once per client, but this object/task should only be instantiated
// once per worker/instance
secondary.spawn(async move {
tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
tracing::debug!("Starting endpoint watcher for prefix: {prefix}");
let mut map = HashMap::new();
loop {
......@@ -245,23 +237,40 @@ impl Client {
match kv_event {
WatchEvent::Put(kv) => {
let key = String::from_utf8(kv.key().to_vec());
let val = serde_json::from_slice::<Instance>(kv.value());
if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val);
} else {
tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}");
break;
let key = kv.key_str();
if !key.starts_with(&prefix) {
continue;
}
}
WatchEvent::Delete(kv) => {
match String::from_utf8(kv.key().to_vec()) {
Ok(key) => { map.remove(&key); }
Err(_) => {
tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
}
match serde_json::from_slice::<Instance>(kv.value()) {
Ok(val) => map.insert(key.to_string(), val),
Err(err) => {
tracing::error!(error = %err, prefix,
"Unable to parse put endpoint event; shutting down endpoint watcher");
break;
}
};
}
WatchEvent::Delete(key) => {
let key = key.as_ref();
if !key.starts_with(&prefix) {
continue;
}
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
}
map.remove(key);
}
}
......
......@@ -4,6 +4,8 @@
use derive_getters::Dissolve;
use tokio_util::sync::CancellationToken;
use crate::storage::key_value_store;
use super::*;
pub use async_nats::service::endpoint::Stats as EndpointStats;
......@@ -118,8 +120,6 @@ impl EndpointConfigBuilder {
let endpoint_name = endpoint.name.clone();
let system_health = endpoint.drt().system_health.clone();
let subject = endpoint.subject_to(connection_id);
let etcd_path = endpoint.etcd_path_with_lease_id(connection_id);
let etcd_client = endpoint.component.drt.etcd_client.clone();
// Register health check target in SystemHealth if provided
if let Some(health_check_payload) = &health_check_payload {
......@@ -193,9 +193,6 @@ impl EndpointConfigBuilder {
result
});
// make the components service endpoint discovery in etcd
// client.register_service()
let info = Instance {
component: component_name.clone(),
endpoint: endpoint_name.clone(),
......@@ -206,15 +203,16 @@ impl EndpointConfigBuilder {
let info = serde_json::to_vec_pretty(&info)?;
if let Some(etcd_client) = &etcd_client
&& let Err(e) = etcd_client
.kv_create(&etcd_path, info, Some(connection_id))
.await
{
let store = endpoint.drt().store();
let instances_bucket = store
.get_or_create_bucket(super::INSTANCE_ROOT_PATH, None)
.await?;
let key = key_value_store::Key::from_raw(endpoint.unique_path(connection_id));
if let Err(err) = instances_bucket.insert(&key, info.into(), 0).await {
tracing::error!(
component_name,
endpoint_name,
error = %e,
error = %err,
"Unable to register service for discovery"
);
endpoint_shutdown_token.cancel();
......
......@@ -3,7 +3,8 @@
pub use crate::component::Component;
use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, MemoryStore,
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
MemoryStore,
};
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
......@@ -48,23 +49,22 @@ impl std::fmt::Debug for DistributedRuntime {
impl DistributedRuntime {
pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
let (etcd_config, nats_config, is_static) = config.dissolve();
let (selected_kv_store, nats_config, is_static) = config.dissolve();
let runtime_clone = runtime.clone();
// TODO: Here is where we will later select the KeyValueStore impl
let (etcd_client, store) = if is_static {
(None, KeyValueStoreManager::memory())
} else {
match etcd::Client::new(etcd_config.clone(), runtime_clone).await {
Ok(etcd_client) => {
let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store)
}
Err(err) => {
tracing::info!(%err, "Did not connect to etcd. Using memory storage.");
(None, KeyValueStoreManager::memory())
}
let (etcd_client, store) = match (is_static, selected_kv_store) {
(false, KeyValueStoreSelect::Etcd(etcd_config)) => {
let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err|
// The returned error doesn't show because of a dropped runtime error, so
// log it first.
tracing::error!(%err, "Could not connect to etcd. Pass `--store-kv ..` to use a different backend or start etcd."))?;
let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store)
}
(false, KeyValueStoreSelect::File(root)) => (None, KeyValueStoreManager::file(root)),
(true, _) | (false, KeyValueStoreSelect::Memory) => {
(None, KeyValueStoreManager::memory())
}
};
......@@ -234,6 +234,7 @@ impl DistributedRuntime {
pub fn shutdown(&self) {
self.runtime.shutdown();
self.store.shutdown();
}
/// Create a [`Namespace`]
......@@ -302,7 +303,7 @@ impl DistributedRuntime {
#[derive(Dissolve)]
pub struct DistributedConfig {
pub etcd_config: etcd::ClientOptions,
pub store_backend: KeyValueStoreSelect,
pub nats_config: nats::ClientOptions,
pub is_static: bool,
}
......@@ -310,22 +311,22 @@ pub struct DistributedConfig {
impl DistributedConfig {
pub fn from_settings(is_static: bool) -> DistributedConfig {
DistributedConfig {
etcd_config: etcd::ClientOptions::default(),
store_backend: KeyValueStoreSelect::Etcd(Box::default()),
nats_config: nats::ClientOptions::default(),
is_static,
}
}
pub fn for_cli() -> DistributedConfig {
let mut config = DistributedConfig {
etcd_config: etcd::ClientOptions::default(),
let etcd_config = etcd::ClientOptions {
attach_lease: false,
..Default::default()
};
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)),
nats_config: nats::ClientOptions::default(),
is_static: false,
};
config.etcd_config.attach_lease = false;
config
}
}
}
......
......@@ -4,14 +4,16 @@
//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else.
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, path::PathBuf};
use std::{env, fmt};
use crate::CancellationToken;
use crate::slug::Slug;
use crate::transports::etcd as etcd_transport;
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
......@@ -22,10 +24,15 @@ mod nats;
pub use nats::NATSStore;
mod etcd;
pub use etcd::EtcdStore;
mod file;
pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// A key that is safe to use directly in the KV store.
///
/// TODO: Need to re-think this. etcd uses slash separators, so we often use from_raw
/// to avoid the slug. But other impl's, particularly file, need a real slug.
#[derive(Debug, Clone, PartialEq)]
pub struct Key(String);
......@@ -95,7 +102,7 @@ impl KeyValue {
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
Put(KeyValue),
Delete(KeyValue),
Delete(Key),
}
#[async_trait]
......@@ -112,6 +119,57 @@ pub trait KeyValueStore: Send + Sync {
async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;
fn connection_id(&self) -> u64;
fn shutdown(&self);
}
#[derive(Clone, Debug, Default)]
pub enum KeyValueStoreSelect {
// Box it because it is significantly bigger than the other variants
Etcd(Box<etcd_transport::ClientOptions>),
File(PathBuf),
#[default]
Memory,
// Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
}
impl fmt::Display for KeyValueStoreSelect {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
KeyValueStoreSelect::Etcd(opts) => {
let urls = opts.etcd_url.join(",");
write!(f, "Etcd({urls})")
}
KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()),
KeyValueStoreSelect::Memory => write!(f, "Memory"),
}
}
}
impl FromStr for KeyValueStoreSelect {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<KeyValueStoreSelect> {
match s {
"etcd" => Ok(Self::Etcd(Box::default())),
"file" => {
let root = env::var("DYN_FILE_KV")
.map(PathBuf::from)
.unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
Ok(Self::File(root))
}
"mem" => Ok(Self::Memory),
x => anyhow::bail!("Unknown key-value store type '{x}'"),
}
}
}
impl TryFrom<String> for KeyValueStoreSelect {
type Error = anyhow::Error;
fn try_from(s: String) -> anyhow::Result<KeyValueStoreSelect> {
s.parse()
}
}
#[allow(clippy::large_enum_variant)]
......@@ -119,6 +177,7 @@ pub enum KeyValueStoreEnum {
Memory(MemoryStore),
Nats(NATSStore),
Etcd(EtcdStore),
File(FileStore),
}
impl KeyValueStoreEnum {
......@@ -133,6 +192,7 @@ impl KeyValueStoreEnum {
Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
})
}
......@@ -154,6 +214,10 @@ impl KeyValueStoreEnum {
.get_bucket(bucket_name)
.await?
.map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
File(x) => x
.get_bucket(bucket_name)
.await?
.map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
};
Ok(maybe_bucket)
}
......@@ -164,12 +228,23 @@ impl KeyValueStoreEnum {
Memory(x) => x.connection_id(),
Etcd(x) => x.connection_id(),
Nats(x) => x.connection_id(),
File(x) => x.connection_id(),
}
}
fn shutdown(&self) {
use KeyValueStoreEnum::*;
match self {
Memory(x) => x.shutdown(),
Etcd(x) => x.shutdown(),
Nats(x) => x.shutdown(),
File(x) => x.shutdown(),
}
}
}
#[derive(Clone)]
pub struct KeyValueStoreManager(Arc<KeyValueStoreEnum>);
pub struct KeyValueStoreManager(pub Arc<KeyValueStoreEnum>);
impl Default for KeyValueStoreManager {
fn default() -> Self {
......@@ -187,6 +262,10 @@ impl KeyValueStoreManager {
Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
}
pub fn file<P: Into<PathBuf>>(root: P) -> Self {
Self::new(KeyValueStoreEnum::File(FileStore::new(root)))
}
fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
KeyValueStoreManager(Arc::new(s))
}
......@@ -302,6 +381,12 @@ impl KeyValueStoreManager {
}
Ok(outcome)
}
/// Cleanup any temporary state.
/// TODO: Should this be async? Take &mut self?
pub fn shutdown(&self) {
self.0.shutdown()
}
}
/// An online storage for key-value config values.
......@@ -366,6 +451,9 @@ pub enum StoreError {
#[error("Internal etcd error: {0}")]
EtcdError(String),
#[error("Internal filesystem error: {0}")]
FilesystemError(String),
#[error("Key Value Error: {0} for bucket '{1}'")]
KeyValueError(String, String),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment