Unverified commit 794c0a44, authored by Graham King, committed by GitHub
Browse files

feat(keyvalue): Filesystem backed KeyValueStore (#4138)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 3fd0ab3d
...@@ -25,7 +25,7 @@ def dynamo_worker(static=False): ...@@ -25,7 +25,7 @@ def dynamo_worker(static=False):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, static) runtime = DistributedRuntime(loop, "etcd", static)
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
...@@ -256,7 +256,7 @@ async def test_server_context_cancel(server, client): ...@@ -256,7 +256,7 @@ async def test_server_context_cancel(server, client):
except ValueError as e: except ValueError as e:
# Verify the expected cancellation exception is received # Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError? # TODO: Should this be a asyncio.CancelledError?
assert str(e) == "Stream ended before generation completed" assert str(e).startswith("Stream ended before generation completed")
# Verify server context cancellation status # Verify server context cancellation status
assert handler.context_is_stopped assert handler.context_is_stopped
......
...@@ -82,20 +82,17 @@ def run_client(example_dir, use_middle=False): ...@@ -82,20 +82,17 @@ def run_client(example_dir, use_middle=False):
) )
# Wait for client to complete # Wait for client to complete
stdout, _ = client_proc.communicate(timeout=1) stdout, _ = client_proc.communicate(timeout=2)
print(f"Client stdout: {stdout}")
if client_proc.returncode != 0:
pytest.fail(
f"Client failed with return code {client_proc.returncode}. Output: {stdout}"
)
return stdout return stdout
def stop_process(process): def stop_process(name, process):
"""Stop a running process and capture its output""" """Stop a running process and capture its output"""
process.terminate() process.terminate()
stdout, _ = process.communicate(timeout=1) stdout, _ = process.communicate(timeout=1)
print(f"{name}: {stdout}")
return stdout return stdout
...@@ -109,7 +106,7 @@ async def test_direct_connection_cancellation(example_dir, server_process): ...@@ -109,7 +106,7 @@ async def test_direct_connection_cancellation(example_dir, server_process):
await asyncio.sleep(1) await asyncio.sleep(1)
# Capture server output # Capture server output
server_output = stop_process(server_process) server_output = stop_process("server_process", server_process)
# Assert expected messages # Assert expected messages
assert ( assert (
...@@ -132,8 +129,8 @@ async def test_middle_server_cancellation( ...@@ -132,8 +129,8 @@ async def test_middle_server_cancellation(
await asyncio.sleep(1) await asyncio.sleep(1)
# Capture output from all processes # Capture output from all processes
server_output = stop_process(server_process) server_output = stop_process("server_process", server_process)
middle_output = stop_process(middle_server_process) middle_output = stop_process("middle_server_process", middle_server_process)
# Assert expected messages # Assert expected messages
assert ( assert (
......
...@@ -153,7 +153,7 @@ def start_nats_and_etcd_default_ports(): ...@@ -153,7 +153,7 @@ def start_nats_and_etcd_default_ports():
print(f"Using ETCD on default client port {etcd_client_port}") print(f"Using ETCD on default client port {etcd_client_port}")
# Start services with default ports # Start services with default ports
nats_server = subprocess.Popen(["nats-server", "-js"]) nats_server = subprocess.Popen(["nats-server", "-js", "--trace"])
etcd = subprocess.Popen(["etcd"]) etcd = subprocess.Popen(["etcd"])
return nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir return nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir
...@@ -181,6 +181,8 @@ def start_nats_and_etcd_random_ports(): ...@@ -181,6 +181,8 @@ def start_nats_and_etcd_random_ports():
etcd = subprocess.Popen( etcd = subprocess.Popen(
[ [
"etcd", "etcd",
"--logger",
"zap",
"--data-dir", "--data-dir",
str(etcd_data_dir), str(etcd_data_dir),
"--listen-client-urls", "--listen-client-urls",
...@@ -221,7 +223,11 @@ def start_nats_and_etcd_random_ports(): ...@@ -221,7 +223,11 @@ def start_nats_and_etcd_random_ports():
msg = log.get("msg", "") msg = log.get("msg", "")
# Look for the client port # Look for the client port
if "serving client traffic" in msg or "serving client" in msg: if (
"serving client traffic" in msg
or "serving client" in msg
or "serving insecure client" in msg
):
address = log.get("address", "") address = log.get("address", "")
match = re.search(r":(\d+)$", address) match = re.search(r":(\d+)$", address)
if match: if match:
...@@ -430,6 +436,6 @@ This is required because DistributedRuntime is a process-level singleton. ...@@ -430,6 +436,6 @@ This is required because DistributedRuntime is a process-level singleton.
) )
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True) runtime = DistributedRuntime(loop, "mem", True)
yield runtime yield runtime
runtime.shutdown() runtime.shutdown()
...@@ -34,7 +34,7 @@ async def distributed_runtime(): ...@@ -34,7 +34,7 @@ async def distributed_runtime():
Each test gets its own runtime in a forked process to avoid singleton conflicts. Each test gets its own runtime in a forked process to avoid singleton conflicts.
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False) runtime = DistributedRuntime(loop, "etcd", False)
yield runtime yield runtime
runtime.shutdown() runtime.shutdown()
......
...@@ -89,8 +89,8 @@ fn parse_sinks_from_env( ...@@ -89,8 +89,8 @@ fn parse_sinks_from_env(
} }
/// spawn one worker per sink; each subscribes to the bus (off hot path) /// spawn one worker per sink; each subscribes to the bus (off hot path)
pub fn spawn_workers_from_env(drt: Option<&dynamo_runtime::DistributedRuntime>) { pub fn spawn_workers_from_env(drt: &dynamo_runtime::DistributedRuntime) {
let nats_client = drt.and_then(|d| d.nats_client()); let nats_client = drt.nats_client();
let sinks = parse_sinks_from_env(nats_client); let sinks = parse_sinks_from_env(nats_client);
for sink in sinks { for sink in sinks {
let name = sink.name(); let name = sink.name();
......
...@@ -183,8 +183,8 @@ impl ModelWatcher { ...@@ -183,8 +183,8 @@ impl ModelWatcher {
} }
} }
} }
WatchEvent::Delete(kv) => { WatchEvent::Delete(key) => {
let deleted_key = kv.key_str(); let deleted_key = key.as_ref();
match self match self
.handle_delete(deleted_key, target_namespace, global_namespace) .handle_delete(deleted_key, target_namespace, global_namespace)
.await .await
......
...@@ -23,7 +23,6 @@ pub mod http; ...@@ -23,7 +23,6 @@ pub mod http;
pub mod text; pub mod text;
use dynamo_runtime::protocols::ENDPOINT_SCHEME; use dynamo_runtime::protocols::ENDPOINT_SCHEME;
use either::Either;
const BATCH_PREFIX: &str = "batch:"; const BATCH_PREFIX: &str = "batch:";
...@@ -107,15 +106,10 @@ impl Default for Input { ...@@ -107,15 +106,10 @@ impl Default for Input {
/// For Input::Endpoint pass a DistributedRuntime. For everything else pass either a Runtime or a /// For Input::Endpoint pass a DistributedRuntime. For everything else pass either a Runtime or a
/// DistributedRuntime. /// DistributedRuntime.
pub async fn run_input( pub async fn run_input(
rt: Either<dynamo_runtime::Runtime, dynamo_runtime::DistributedRuntime>, drt: dynamo_runtime::DistributedRuntime,
in_opt: Input, in_opt: Input,
engine_config: super::EngineConfig, engine_config: super::EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let runtime = match &rt {
Either::Left(rt) => rt.clone(),
Either::Right(drt) => drt.runtime().clone(),
};
// Initialize audit bus + sink workers (off hot path; fan-out supported) // Initialize audit bus + sink workers (off hot path; fan-out supported)
if crate::audit::config::policy().enabled { if crate::audit::config::policy().enabled {
let cap: usize = std::env::var("DYN_AUDIT_CAPACITY") let cap: usize = std::env::var("DYN_AUDIT_CAPACITY")
...@@ -123,38 +117,30 @@ pub async fn run_input( ...@@ -123,38 +117,30 @@ pub async fn run_input(
.and_then(|v| v.parse().ok()) .and_then(|v| v.parse().ok())
.unwrap_or(1024); .unwrap_or(1024);
crate::audit::bus::init(cap); crate::audit::bus::init(cap);
// Pass DistributedRuntime if available for shared NATS client crate::audit::sink::spawn_workers_from_env(&drt);
let drt_ref = match &rt { tracing::info!(cap, "Audit initialized");
Either::Right(drt) => Some(drt),
Either::Left(_) => None,
};
crate::audit::sink::spawn_workers_from_env(drt_ref);
tracing::info!("Audit initialized: bus cap={}", cap);
} }
match in_opt { match in_opt {
Input::Http => { Input::Http => {
http::run(runtime, engine_config).await?; http::run(drt, engine_config).await?;
} }
Input::Grpc => { Input::Grpc => {
grpc::run(runtime, engine_config).await?; grpc::run(drt, engine_config).await?;
} }
Input::Text => { Input::Text => {
text::run(runtime, None, engine_config).await?; text::run(drt, None, engine_config).await?;
} }
Input::Stdin => { Input::Stdin => {
let mut prompt = String::new(); let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap(); std::io::stdin().read_to_string(&mut prompt).unwrap();
text::run(runtime, Some(prompt), engine_config).await?; text::run(drt, Some(prompt), engine_config).await?;
} }
Input::Batch(path) => { Input::Batch(path) => {
batch::run(runtime, path, engine_config).await?; batch::run(drt, path, engine_config).await?;
} }
Input::Endpoint(path) => { Input::Endpoint(path) => {
let Either::Right(distributed_runtime) = rt else { endpoint::run(drt, path, engine_config).await?;
anyhow::bail!("Input::Endpoint requires passing a DistributedRuntime");
};
endpoint::run(distributed_runtime, path, engine_config).await?;
} }
} }
Ok(()) Ok(())
......
...@@ -8,7 +8,7 @@ use crate::types::openai::chat_completions::{ ...@@ -8,7 +8,7 @@ use crate::types::openai::chat_completions::{
}; };
use anyhow::Context as _; use anyhow::Context as _;
use dynamo_async_openai::types::FinishReason; use dynamo_async_openai::types::FinishReason;
use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt; use futures::StreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp; use std::cmp;
...@@ -51,11 +51,11 @@ struct Entry { ...@@ -51,11 +51,11 @@ struct Entry {
} }
pub async fn run( pub async fn run(
runtime: Runtime, distributed_runtime: DistributedRuntime,
input_jsonl: PathBuf, input_jsonl: PathBuf,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = distributed_runtime.primary_token();
// Check if the path exists and is a directory // Check if the path exists and is a directory
if !input_jsonl.exists() || !input_jsonl.is_file() { if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!( anyhow::bail!(
...@@ -64,7 +64,7 @@ pub async fn run( ...@@ -64,7 +64,7 @@ pub async fn run(
); );
} }
let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?; let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?;
let pre_processor = if prepared_engine.has_tokenizer() { let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new( Some(OpenAIPreprocessor::new(
......
...@@ -24,9 +24,8 @@ use crate::{ ...@@ -24,9 +24,8 @@ use crate::{
}; };
use dynamo_runtime::{ use dynamo_runtime::{
DistributedRuntime, Runtime, DistributedRuntime,
component::Client, component::Client,
distributed::DistributedConfig,
engine::{AsyncEngineStream, Data}, engine::{AsyncEngineStream, Data},
pipeline::{ pipeline::{
Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend, Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend,
...@@ -55,23 +54,25 @@ impl PreparedEngine { ...@@ -55,23 +54,25 @@ impl PreparedEngine {
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
pub async fn prepare_engine( pub async fn prepare_engine(
runtime: Runtime, distributed_runtime: DistributedRuntime,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic(local_model) => { EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let store = Arc::new(distributed_runtime.store().clone()); let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new( let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime, distributed_runtime.clone(),
model_manager.clone(), model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin, dynamo_runtime::pipeline::RouterMode::RoundRobin,
None, None,
None, None,
)); ));
let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token()); let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await; inner_watch_obj.watch(receiver, None).await;
...@@ -98,9 +99,6 @@ pub async fn prepare_engine( ...@@ -98,9 +99,6 @@ pub async fn prepare_engine(
let card = local_model.card(); let card = local_model.card();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true);
let distributed_runtime = DistributedRuntime::new(runtime, dst_config).await?;
let endpoint_id = local_model.endpoint_id(); let endpoint_id = local_model.endpoint_id();
let component = distributed_runtime let component = distributed_runtime
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
......
...@@ -16,18 +16,20 @@ use crate::{ ...@@ -16,18 +16,20 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager}; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; use dynamo_runtime::{DistributedRuntime, storage::key_value_store::KeyValueStoreManager};
/// Build and run an KServe gRPC service /// Build and run an KServe gRPC service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { pub async fn run(
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let grpc_service_builder = kserve::KserveService::builder() let grpc_service_builder = kserve::KserveService::builder()
.port(engine_config.local_model().http_port()) // [WIP] generalize port.. .port(engine_config.local_model().http_port()) // [WIP] generalize port..
.with_request_template(engine_config.local_model().request_template()); .with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let store = Arc::new(distributed_runtime.store().clone()); let store = Arc::new(distributed_runtime.store().clone());
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config(); let router_config = engine_config.local_model().router_config();
...@@ -39,7 +41,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -39,7 +41,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(namespace.to_string()) Some(namespace.to_string())
}; };
run_watcher( run_watcher(
distributed_runtime, distributed_runtime.clone(),
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
store, store,
router_config.router_mode, router_config.router_mode,
...@@ -55,8 +57,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -55,8 +57,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let checksum = card.mdcsum(); let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
...@@ -157,8 +157,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -157,8 +157,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
grpc_service grpc_service
} }
}; };
grpc_service.run(runtime.primary_token()).await?; grpc_service
runtime.shutdown(); // Cancel primary token .run(distributed_runtime.primary_token())
.await?;
distributed_runtime.shutdown(); // Cancel primary token
Ok(()) Ok(())
} }
......
...@@ -17,12 +17,15 @@ use crate::{ ...@@ -17,12 +17,15 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager; use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
use dynamo_runtime::{DistributedRuntime, Runtime};
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { pub async fn run(
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let local_model = engine_config.local_model(); let local_model = engine_config.local_model();
let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) { let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
(Some(tls_cert_path), Some(tls_key_path)) => { (Some(tls_cert_path), Some(tls_key_path)) => {
...@@ -63,7 +66,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -63,7 +66,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = match engine_config { let http_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
// This allows the /health endpoint to query store for active instances // This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
...@@ -80,7 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -80,7 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(namespace.to_string()) Some(namespace.to_string())
}; };
run_watcher( run_watcher(
distributed_runtime, distributed_runtime.clone(),
http_service.state().manager_clone(), http_service.state().manager_clone(),
store, store,
router_config.router_mode, router_config.router_mode,
...@@ -96,11 +98,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -96,11 +98,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
let card = local_model.card(); let card = local_model.card();
let checksum = card.mdcsum(); let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let manager = http_service.model_manager(); let manager = http_service.model_manager();
...@@ -233,8 +231,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -233,8 +231,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
http_service.custom_backend_metrics_polling_interval, http_service.custom_backend_metrics_polling_interval,
http_service.custom_backend_registry.as_ref(), http_service.custom_backend_registry.as_ref(),
) { ) {
// Create DistributedRuntime for polling, matching the engine's mode
let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
tracing::info!( tracing::info!(
namespace_component_endpoint=%namespace_component_endpoint, namespace_component_endpoint=%namespace_component_endpoint,
polling_interval_secs=polling_interval, polling_interval_secs=polling_interval,
...@@ -246,7 +242,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -246,7 +242,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
// shutdown phase. // shutdown phase.
Some( Some(
crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task( crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
drt, distributed_runtime.clone(),
namespace_component_endpoint.clone(), namespace_component_endpoint.clone(),
polling_interval, polling_interval,
registry.clone(), registry.clone(),
...@@ -256,14 +252,16 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -256,14 +252,16 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None None
}; };
http_service.run(runtime.primary_token()).await?; http_service
.run(distributed_runtime.primary_token())
.await?;
// Abort the polling task if it was started // Abort the polling task if it was started
if let Some(task) = polling_task { if let Some(task) = polling_task {
task.abort(); task.abort();
} }
runtime.shutdown(); // Cancel primary token distributed_runtime.shutdown(); // Cancel primary token
Ok(()) Ok(())
} }
......
...@@ -5,7 +5,8 @@ use crate::request_template::RequestTemplate; ...@@ -5,7 +5,8 @@ use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{ use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::Context;
use futures::StreamExt; use futures::StreamExt;
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
...@@ -17,15 +18,15 @@ use crate::entrypoint::input::common; ...@@ -17,15 +18,15 @@ use crate::entrypoint::input::common;
const MAX_TOKENS: u32 = 8192; const MAX_TOKENS: u32 = 8192;
pub async fn run( pub async fn run(
runtime: Runtime, distributed_runtime: DistributedRuntime,
single_prompt: Option<String>, single_prompt: Option<String>,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let prepared_engine =
let prepared_engine = common::prepare_engine(runtime, engine_config).await?; common::prepare_engine(distributed_runtime.clone(), engine_config).await?;
// TODO: Pass prepared_engine directly // TODO: Pass prepared_engine directly
main_loop( main_loop(
cancel_token, distributed_runtime,
&prepared_engine.service_name, &prepared_engine.service_name,
prepared_engine.engine, prepared_engine.engine,
single_prompt, single_prompt,
...@@ -36,13 +37,14 @@ pub async fn run( ...@@ -36,13 +37,14 @@ pub async fn run(
} }
async fn main_loop( async fn main_loop(
cancel_token: CancellationToken, distributed_runtime: DistributedRuntime,
service_name: &str, service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
mut initial_prompt: Option<String>, mut initial_prompt: Option<String>,
_inspect_template: bool, _inspect_template: bool,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = distributed_runtime.primary_token();
if initial_prompt.is_none() { if initial_prompt.is_none() {
tracing::info!("Ctrl-c to exit"); tracing::info!("Ctrl-c to exit");
} }
...@@ -179,7 +181,11 @@ async fn main_loop( ...@@ -179,7 +181,11 @@ async fn main_loop(
break; break;
} }
} }
cancel_token.cancel(); // stop everything else
println!(); println!();
// Stop the runtime and wait for it to stop
distributed_runtime.shutdown();
cancel_token.cancelled().await;
Ok(()) Ok(())
} }
...@@ -167,7 +167,7 @@ mod tests { ...@@ -167,7 +167,7 @@ mod tests {
bus::init(100); bus::init(100);
let drt = create_test_drt().await; let drt = create_test_drt().await;
sink::spawn_workers_from_env(Some(&drt)); sink::spawn_workers_from_env(&drt);
time::sleep(Duration::from_millis(100)).await; time::sleep(Duration::from_millis(100)).await;
// Emit audit record // Emit audit record
...@@ -224,7 +224,7 @@ mod tests { ...@@ -224,7 +224,7 @@ mod tests {
bus::init(100); bus::init(100);
let drt = create_test_drt().await; let drt = create_test_drt().await;
sink::spawn_workers_from_env(Some(&drt)); sink::spawn_workers_from_env(&drt);
time::sleep(Duration::from_millis(100)).await; time::sleep(Duration::from_millis(100)).await;
// Request with store=true (should be audited) // Request with store=true (should be audited)
......
...@@ -63,6 +63,7 @@ bincode = { version = "1" } ...@@ -63,6 +63,7 @@ bincode = { version = "1" }
console-subscriber = { version = "0.4", optional = true } console-subscriber = { version = "0.4", optional = true }
educe = { version = "0.6.0" } educe = { version = "0.6.0" }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] } figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
inotify = { version = "0.11" }
local-ip-address = { version = "0.6.3" } local-ip-address = { version = "0.6.3" }
log = { version = "0.4" } log = { version = "0.4" }
nid = { version = "3.0.0", features = ["serde"] } nid = { version = "3.0.0", features = ["serde"] }
......
...@@ -679,6 +679,7 @@ dependencies = [ ...@@ -679,6 +679,7 @@ dependencies = [
"figment", "figment",
"futures", "futures",
"humantime", "humantime",
"inotify",
"local-ip-address", "local-ip-address",
"log", "log",
"nid", "nid",
...@@ -1354,6 +1355,28 @@ version = "0.1.15" ...@@ -1354,6 +1355,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.0",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "iovec" name = "iovec"
version = "0.1.4" version = "0.1.4"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::pipeline::{ use crate::{
pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn, SingleIn,
},
storage::key_value_store::{KeyValueStoreManager, WatchEvent},
}; };
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::unix::pipe::Receiver; use tokio::net::unix::pipe::Receiver;
use crate::{ use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient};
pipeline::async_trait,
transports::etcd::{Client as EtcdClient, WatchEvent},
};
use super::*; use super::*;
...@@ -70,12 +70,7 @@ impl Client { ...@@ -70,12 +70,7 @@ impl Client {
const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1); const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
// create live endpoint watcher // create live endpoint watcher
let Some(etcd_client) = &endpoint.component.drt.etcd_client else { let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
};
let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
let client = Client { let client = Client {
endpoint, endpoint,
...@@ -194,7 +189,6 @@ impl Client { ...@@ -194,7 +189,6 @@ impl Client {
} }
async fn get_or_create_dynamic_instance_source( async fn get_or_create_dynamic_instance_source(
etcd_client: &EtcdClient,
endpoint: &Endpoint, endpoint: &Endpoint,
) -> Result<Arc<InstanceSource>> { ) -> Result<Arc<InstanceSource>> {
let drt = endpoint.drt(); let drt = endpoint.drt();
...@@ -209,12 +203,10 @@ impl Client { ...@@ -209,12 +203,10 @@ impl Client {
} }
} }
let prefix_watcher = etcd_client let prefix = endpoint.etcd_root();
.kv_get_and_watch_prefix(endpoint.etcd_root()) let store = Arc::new(drt.store().clone());
.await?; let (_, mut kv_event_rx) =
store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token());
let (prefix, mut kv_event_rx) = prefix_watcher.dissolve();
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]); let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let secondary = endpoint.component.drt.runtime.secondary().clone(); let secondary = endpoint.component.drt.runtime.secondary().clone();
...@@ -223,7 +215,7 @@ impl Client { ...@@ -223,7 +215,7 @@ impl Client {
// currently this is created once per client, but this object/task should only be instantiated // currently this is created once per client, but this object/task should only be instantiated
// once per worker/instance // once per worker/instance
secondary.spawn(async move { secondary.spawn(async move {
tracing::debug!("Starting endpoint watcher for prefix: {}", prefix); tracing::debug!("Starting endpoint watcher for prefix: {prefix}");
let mut map = HashMap::new(); let mut map = HashMap::new();
loop { loop {
...@@ -245,23 +237,40 @@ impl Client { ...@@ -245,23 +237,40 @@ impl Client {
match kv_event { match kv_event {
WatchEvent::Put(kv) => { WatchEvent::Put(kv) => {
let key = String::from_utf8(kv.key().to_vec()); let key = kv.key_str();
let val = serde_json::from_slice::<Instance>(kv.value()); if !key.starts_with(&prefix) {
if let (Ok(key), Ok(val)) = (key, val) { continue;
map.insert(key.clone(), val);
} else {
tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}");
break;
} }
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
} }
WatchEvent::Delete(kv) => {
match String::from_utf8(kv.key().to_vec()) { match serde_json::from_slice::<Instance>(kv.value()) {
Ok(key) => { map.remove(&key); } Ok(val) => map.insert(key.to_string(), val),
Err(_) => { Err(err) => {
tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix); tracing::error!(error = %err, prefix,
"Unable to parse put endpoint event; shutting down endpoint watcher");
break; break;
} }
};
}
WatchEvent::Delete(key) => {
let key = key.as_ref();
if !key.starts_with(&prefix) {
continue;
}
let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else {
tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible.");
continue;
};
if key.starts_with("/") {
key = &key[1..];
} }
map.remove(key);
} }
} }
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
use derive_getters::Dissolve; use derive_getters::Dissolve;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::storage::key_value_store;
use super::*; use super::*;
pub use async_nats::service::endpoint::Stats as EndpointStats; pub use async_nats::service::endpoint::Stats as EndpointStats;
...@@ -118,8 +120,6 @@ impl EndpointConfigBuilder { ...@@ -118,8 +120,6 @@ impl EndpointConfigBuilder {
let endpoint_name = endpoint.name.clone(); let endpoint_name = endpoint.name.clone();
let system_health = endpoint.drt().system_health.clone(); let system_health = endpoint.drt().system_health.clone();
let subject = endpoint.subject_to(connection_id); let subject = endpoint.subject_to(connection_id);
let etcd_path = endpoint.etcd_path_with_lease_id(connection_id);
let etcd_client = endpoint.component.drt.etcd_client.clone();
// Register health check target in SystemHealth if provided // Register health check target in SystemHealth if provided
if let Some(health_check_payload) = &health_check_payload { if let Some(health_check_payload) = &health_check_payload {
...@@ -193,9 +193,6 @@ impl EndpointConfigBuilder { ...@@ -193,9 +193,6 @@ impl EndpointConfigBuilder {
result result
}); });
// make the components service endpoint discovery in etcd
// client.register_service()
let info = Instance { let info = Instance {
component: component_name.clone(), component: component_name.clone(),
endpoint: endpoint_name.clone(), endpoint: endpoint_name.clone(),
...@@ -206,15 +203,16 @@ impl EndpointConfigBuilder { ...@@ -206,15 +203,16 @@ impl EndpointConfigBuilder {
let info = serde_json::to_vec_pretty(&info)?; let info = serde_json::to_vec_pretty(&info)?;
if let Some(etcd_client) = &etcd_client let store = endpoint.drt().store();
&& let Err(e) = etcd_client let instances_bucket = store
.kv_create(&etcd_path, info, Some(connection_id)) .get_or_create_bucket(super::INSTANCE_ROOT_PATH, None)
.await .await?;
{ let key = key_value_store::Key::from_raw(endpoint.unique_path(connection_id));
if let Err(err) = instances_bucket.insert(&key, info.into(), 0).await {
tracing::error!( tracing::error!(
component_name, component_name,
endpoint_name, endpoint_name,
error = %e, error = %err,
"Unable to register service for discovery" "Unable to register service for discovery"
); );
endpoint_shutdown_token.cancel(); endpoint_shutdown_token.cancel();
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
pub use crate::component::Component; pub use crate::component::Component;
use crate::storage::key_value_store::{ use crate::storage::key_value_store::{
EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, MemoryStore, EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect,
MemoryStore,
}; };
use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{ use crate::{
...@@ -48,24 +49,23 @@ impl std::fmt::Debug for DistributedRuntime { ...@@ -48,24 +49,23 @@ impl std::fmt::Debug for DistributedRuntime {
impl DistributedRuntime { impl DistributedRuntime {
pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> { pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
let (etcd_config, nats_config, is_static) = config.dissolve(); let (selected_kv_store, nats_config, is_static) = config.dissolve();
let runtime_clone = runtime.clone(); let runtime_clone = runtime.clone();
// TODO: Here is where we will later select the KeyValueStore impl let (etcd_client, store) = match (is_static, selected_kv_store) {
let (etcd_client, store) = if is_static { (false, KeyValueStoreSelect::Etcd(etcd_config)) => {
(None, KeyValueStoreManager::memory()) let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err|
} else { // The returned error doesn't show because of a dropped runtime error, so
match etcd::Client::new(etcd_config.clone(), runtime_clone).await { // log it first.
Ok(etcd_client) => { tracing::error!(%err, "Could not connect to etcd. Pass `--store-kv ..` to use a different backend or start etcd."))?;
let store = KeyValueStoreManager::etcd(etcd_client.clone()); let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store) (Some(etcd_client), store)
} }
Err(err) => { (false, KeyValueStoreSelect::File(root)) => (None, KeyValueStoreManager::file(root)),
tracing::info!(%err, "Did not connect to etcd. Using memory storage."); (true, _) | (false, KeyValueStoreSelect::Memory) => {
(None, KeyValueStoreManager::memory()) (None, KeyValueStoreManager::memory())
} }
}
}; };
let nats_client = Some(nats_config.clone().connect().await?); let nats_client = Some(nats_config.clone().connect().await?);
...@@ -234,6 +234,7 @@ impl DistributedRuntime { ...@@ -234,6 +234,7 @@ impl DistributedRuntime {
pub fn shutdown(&self) { pub fn shutdown(&self) {
self.runtime.shutdown(); self.runtime.shutdown();
self.store.shutdown();
} }
/// Create a [`Namespace`] /// Create a [`Namespace`]
...@@ -302,7 +303,7 @@ impl DistributedRuntime { ...@@ -302,7 +303,7 @@ impl DistributedRuntime {
#[derive(Dissolve)] #[derive(Dissolve)]
pub struct DistributedConfig { pub struct DistributedConfig {
pub etcd_config: etcd::ClientOptions, pub store_backend: KeyValueStoreSelect,
pub nats_config: nats::ClientOptions, pub nats_config: nats::ClientOptions,
pub is_static: bool, pub is_static: bool,
} }
...@@ -310,22 +311,22 @@ pub struct DistributedConfig { ...@@ -310,22 +311,22 @@ pub struct DistributedConfig {
impl DistributedConfig { impl DistributedConfig {
pub fn from_settings(is_static: bool) -> DistributedConfig { pub fn from_settings(is_static: bool) -> DistributedConfig {
DistributedConfig { DistributedConfig {
etcd_config: etcd::ClientOptions::default(), store_backend: KeyValueStoreSelect::Etcd(Box::default()),
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
is_static, is_static,
} }
} }
pub fn for_cli() -> DistributedConfig { pub fn for_cli() -> DistributedConfig {
let mut config = DistributedConfig { let etcd_config = etcd::ClientOptions {
etcd_config: etcd::ClientOptions::default(), attach_lease: false,
..Default::default()
};
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)),
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
is_static: false, is_static: false,
}; }
config.etcd_config.attach_lease = false;
config
} }
} }
......
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
//! Interface to a traditional key-value store such as etcd. //! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else. //! "key_value_store" spelt out because in AI land "KV" means something else.
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin; use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::{collections::HashMap, path::PathBuf};
use std::{env, fmt};
use crate::CancellationToken; use crate::CancellationToken;
use crate::slug::Slug; use crate::slug::Slug;
use crate::transports::etcd as etcd_transport;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -22,10 +24,15 @@ mod nats; ...@@ -22,10 +24,15 @@ mod nats;
pub use nats::NATSStore; pub use nats::NATSStore;
mod etcd; mod etcd;
pub use etcd::EtcdStore; pub use etcd::EtcdStore;
mod file;
pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100); const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// A key that is safe to use directly in the KV store. /// A key that is safe to use directly in the KV store.
///
/// TODO: Need to re-think this. etcd uses slash separators, so we often use `from_raw`
/// to avoid the slug. But other implementations, particularly file, need a real slug.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Key(String); pub struct Key(String);
...@@ -95,7 +102,7 @@ impl KeyValue { ...@@ -95,7 +102,7 @@ impl KeyValue {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent { pub enum WatchEvent {
Put(KeyValue), Put(KeyValue),
Delete(KeyValue), Delete(Key),
} }
#[async_trait] #[async_trait]
...@@ -112,6 +119,57 @@ pub trait KeyValueStore: Send + Sync { ...@@ -112,6 +119,57 @@ pub trait KeyValueStore: Send + Sync {
async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>; async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;
fn connection_id(&self) -> u64; fn connection_id(&self) -> u64;
/// Shut the store down.
///
/// Reached through `KeyValueStoreManager::shutdown`, whose stated purpose is to
/// clean up any temporary state.
fn shutdown(&self);
}
/// Which key-value store backend a runtime should be built with.
///
/// Can be parsed from the strings `"etcd"`, `"file"` and `"mem"` (see the `FromStr` impl).
#[derive(Clone, Debug, Default)]
pub enum KeyValueStoreSelect {
// Box it because it is significantly bigger than the other variants
/// etcd backend, configured by the boxed client options.
Etcd(Box<etcd_transport::ClientOptions>),
/// Filesystem backend; the path is the root handed to the file store.
File(PathBuf),
/// In-memory backend. This is the `Default` variant.
#[default]
Memory,
// Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
}
impl fmt::Display for KeyValueStoreSelect {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
KeyValueStoreSelect::Etcd(opts) => {
let urls = opts.etcd_url.join(",");
write!(f, "Etcd({urls})")
}
KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()),
KeyValueStoreSelect::Memory => write!(f, "Memory"),
}
}
}
impl FromStr for KeyValueStoreSelect {
    type Err = anyhow::Error;

    /// Parse a backend name into a [`KeyValueStoreSelect`].
    ///
    /// Accepted values (matched exactly, case-sensitive):
    /// - `"etcd"`: etcd with default client options.
    /// - `"file"`: filesystem store rooted at the path in the `DYN_FILE_KV`
    ///   environment variable, or at `dynamo_store_kv` inside the system
    ///   temporary directory when that variable is not set.
    /// - `"mem"`: in-memory store.
    ///
    /// # Errors
    ///
    /// Returns an error naming the input when it is none of the values above.
    fn from_str(s: &str) -> anyhow::Result<KeyValueStoreSelect> {
        match s {
            "etcd" => Ok(Self::Etcd(Box::default())),
            "file" => {
                // Use `var_os`, not `var`: a filesystem path need not be valid Unicode,
                // and `var` reports such a value as an error, which would silently
                // discard the user's setting in favour of the temp-dir fallback.
                let root = env::var_os("DYN_FILE_KV")
                    .map(PathBuf::from)
                    .unwrap_or_else(|| env::temp_dir().join("dynamo_store_kv"));
                Ok(Self::File(root))
            }
            "mem" => Ok(Self::Memory),
            x => anyhow::bail!("Unknown key-value store type '{x}'"),
        }
    }
}
impl TryFrom<String> for KeyValueStoreSelect {
type Error = anyhow::Error;
fn try_from(s: String) -> anyhow::Result<KeyValueStoreSelect> {
s.parse()
}
} }
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
...@@ -119,6 +177,7 @@ pub enum KeyValueStoreEnum { ...@@ -119,6 +177,7 @@ pub enum KeyValueStoreEnum {
Memory(MemoryStore), Memory(MemoryStore),
Nats(NATSStore), Nats(NATSStore),
Etcd(EtcdStore), Etcd(EtcdStore),
File(FileStore),
} }
impl KeyValueStoreEnum { impl KeyValueStoreEnum {
...@@ -133,6 +192,7 @@ impl KeyValueStoreEnum { ...@@ -133,6 +192,7 @@ impl KeyValueStoreEnum {
Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
}) })
} }
...@@ -154,6 +214,10 @@ impl KeyValueStoreEnum { ...@@ -154,6 +214,10 @@ impl KeyValueStoreEnum {
.get_bucket(bucket_name) .get_bucket(bucket_name)
.await? .await?
.map(|b| Box::new(b) as Box<dyn KeyValueBucket>), .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
File(x) => x
.get_bucket(bucket_name)
.await?
.map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
}; };
Ok(maybe_bucket) Ok(maybe_bucket)
} }
...@@ -164,12 +228,23 @@ impl KeyValueStoreEnum { ...@@ -164,12 +228,23 @@ impl KeyValueStoreEnum {
Memory(x) => x.connection_id(), Memory(x) => x.connection_id(),
Etcd(x) => x.connection_id(), Etcd(x) => x.connection_id(),
Nats(x) => x.connection_id(), Nats(x) => x.connection_id(),
File(x) => x.connection_id(),
}
}
fn shutdown(&self) {
use KeyValueStoreEnum::*;
match self {
Memory(x) => x.shutdown(),
Etcd(x) => x.shutdown(),
Nats(x) => x.shutdown(),
File(x) => x.shutdown(),
} }
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct KeyValueStoreManager(Arc<KeyValueStoreEnum>); pub struct KeyValueStoreManager(pub Arc<KeyValueStoreEnum>);
impl Default for KeyValueStoreManager { impl Default for KeyValueStoreManager {
fn default() -> Self { fn default() -> Self {
...@@ -187,6 +262,10 @@ impl KeyValueStoreManager { ...@@ -187,6 +262,10 @@ impl KeyValueStoreManager {
Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client))) Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
} }
/// Build a manager backed by a [`FileStore`] created from `root`.
pub fn file<P: Into<PathBuf>>(root: P) -> Self {
    let backend = FileStore::new(root);
    Self::new(KeyValueStoreEnum::File(backend))
}
fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager { fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
KeyValueStoreManager(Arc::new(s)) KeyValueStoreManager(Arc::new(s))
} }
...@@ -302,6 +381,12 @@ impl KeyValueStoreManager { ...@@ -302,6 +381,12 @@ impl KeyValueStoreManager {
} }
Ok(outcome) Ok(outcome)
} }
/// Cleanup any temporary state.
/// TODO: Should this be async? Take &mut self?
pub fn shutdown(&self) {
self.0.shutdown()
}
} }
/// An online storage for key-value config values. /// An online storage for key-value config values.
...@@ -366,6 +451,9 @@ pub enum StoreError { ...@@ -366,6 +451,9 @@ pub enum StoreError {
#[error("Internal etcd error: {0}")] #[error("Internal etcd error: {0}")]
EtcdError(String), EtcdError(String),
#[error("Internal filesystem error: {0}")]
FilesystemError(String),
#[error("Key Value Error: {0} for bucket '{1}'")] #[error("Key Value Error: {0} for bucket '{1}'")]
KeyValueError(String, String), KeyValueError(String, String),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment