Unverified commit 794c0a44, authored by Graham King, committed by GitHub
Browse files

feat(keyvalue): Filesystem backed KeyValueStore (#4138)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 3fd0ab3d
......@@ -2391,6 +2391,7 @@ dependencies = [
"figment",
"futures",
"humantime",
"inotify",
"jsonschema",
"local-ip-address",
"log",
......@@ -3992,6 +3993,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.4",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "insta"
version = "1.43.2"
......
......@@ -225,6 +225,12 @@ def parse_args():
),
help=f"Interval in seconds for polling custom backend metrics. Set to > 0 to enable polling (default: 0=disabled, suggested: 9.2s which is less than typical Prometheus scrape interval). Can be set via {CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR} env var.",
)
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
flags = parser.parse_args()
......@@ -252,8 +258,7 @@ async def async_main():
os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, is_static)
runtime = DistributedRuntime(loop, flags.store_kv, is_static)
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
......
......@@ -204,6 +204,12 @@ def parse_args():
default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)",
)
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
args = parser.parse_args()
validate_worker_type_args(args)
......
......@@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
# Create a separate DistributedRuntime for this worker (on same event loop)
runtime = DistributedRuntime(loop, False)
runtime = DistributedRuntime(loop, args.store_kv, False)
runtimes.append(runtime)
# Create EntrypointArgs for this worker
......
......@@ -93,6 +93,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": None,
"help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.",
},
"store-kv": {
"flags": ["--store-kv"],
"type": str,
"default": os.environ.get("DYN_STORE_KV", "etcd"),
"help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
},
}
......@@ -102,6 +108,7 @@ class DynamoArgs:
component: str
endpoint: str
migration_limit: int
store_kv: str
# tool and reasoning parser options
tool_call_parser: Optional[str] = None
......@@ -329,6 +336,7 @@ async def parse_args(args: list[str]) -> Config:
component=parsed_component_name,
endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit,
store_kv=parsed_args.store_kv,
tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path,
......
......@@ -11,7 +11,7 @@ import uvloop
from dynamo.common.config_dump import dump_config
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
......@@ -33,9 +33,12 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging()
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
async def worker():
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.dynamo_args.store_kv, False)
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
......@@ -45,9 +48,6 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
if config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor:
......
......@@ -39,7 +39,7 @@ import dynamo.nixl_connect as nixl_connect
from dynamo.common.config_dump import dump_config
from dynamo.common.utils.prometheus import register_engine_metrics_callback
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
......@@ -102,11 +102,13 @@ async def get_engine_runtime_config(
return runtime_config
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Set up signal handler for graceful shutdown
async def worker():
config = cmd_line_args()
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, False)
# Set up signal handler for graceful shutdown
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime))
......@@ -116,7 +118,6 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers set up for graceful shutdown")
config = cmd_line_args()
await init(runtime, config)
......
......@@ -58,6 +58,7 @@ class Config:
self.tool_call_parser: Optional[str] = None
self.dump_config_to: Optional[str] = None
self.custom_jinja_template: Optional[str] = None
self.store_kv: str = ""
def __str__(self) -> str:
return (
......@@ -87,8 +88,9 @@ class Config:
f"max_file_size_mb={self.max_file_size_mb}, "
f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}, "
f"dump_config_to={self.dump_config_to},"
f"custom_jinja_template={self.custom_jinja_template}"
f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}"
)
......@@ -278,6 +280,12 @@ def cmd_line_args():
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
args = parser.parse_args()
......@@ -337,6 +345,7 @@ def cmd_line_args():
config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser
config.dump_config_to = args.dump_config_to
config.store_kv = args.store_kv
# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
......
......@@ -38,6 +38,7 @@ class Config:
migration_limit: int = 0
kv_port: Optional[int] = None
custom_jinja_template: Optional[str] = None
store_kv: str
# mirror vLLM
model: str
......@@ -164,6 +165,12 @@ def parse_args() -> Config:
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
......@@ -233,6 +240,7 @@ def parse_args() -> Config:
config.multimodal_worker = args.multimodal_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
# Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None:
......
......@@ -25,7 +25,7 @@ from dynamo.llm import (
fetch_llm,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler,
......@@ -70,16 +70,16 @@ async def graceful_shutdown(runtime):
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
async def worker():
config = parse_args()
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, False)
await configure_ports(config)
overwrite_args(config)
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
......
......@@ -50,7 +50,7 @@ async def main():
return
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
runtime = DistributedRuntime(loop, "mem", True)
# Connect to middle server or direct server based on argument
if use_middle_server:
......
......@@ -50,7 +50,7 @@ class MiddleServer:
async def main():
"""Start the middle server"""
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
runtime = DistributedRuntime(loop, "mem", True)
# Create middle server handler
handler = MiddleServer(runtime)
......
......@@ -31,7 +31,7 @@ class DemoServer:
async def main():
"""Start the demo server"""
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
runtime = DistributedRuntime(loop, "mem", True)
# Create server component
component = runtime.namespace("demo").component("server")
......
......@@ -123,7 +123,7 @@ async def async_main():
# Create DistributedRuntime - similar to frontend/main.py line 246
is_static = True # Use static mode (no etcd)
runtime = DistributedRuntime(loop, is_static) # type: ignore[call-arg]
runtime = DistributedRuntime(loop, "mem", is_static) # type: ignore[call-arg]
# Setup signal handlers for graceful shutdown
def signal_handler():
......
......@@ -127,6 +127,12 @@ pub struct Flags {
#[arg(long, default_value = "false")]
pub static_worker: bool,
/// Which key-value backend to use: etcd, mem, file.
/// Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details.
/// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.
#[arg(long, default_value = "etcd")]
pub store_kv: String,
/// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......
......@@ -6,10 +6,11 @@ use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::distributed::DistributedConfig;
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use dynamo_runtime::transports::nats;
use dynamo_runtime::{DistributedRuntime, Runtime};
mod flags;
use either::Either;
pub use flags::Flags;
mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
......@@ -73,14 +74,16 @@ pub async fn run(
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
let mut rt = Either::Left(runtime.clone());
if let Input::Endpoint(path) = &in_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
let dst_config = DistributedConfig::from_settings(flags.static_worker);
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
rt = Either::Right(distributed_runtime);
}
let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?;
let dst_config = DistributedConfig {
store_backend: selected_store,
nats_config: nats::ClientOptions::default(),
is_static: flags.static_worker,
};
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
if let Some(Output::Static(path)) = &out_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
}
......@@ -98,10 +101,16 @@ pub async fn run(
flags.validate(&in_opt, &out_opt)?;
// Make an engine from the local_model, flags and output.
let engine_config = engine_for(out_opt, flags.clone(), local_model, rt.clone()).await?;
let engine_config = engine_for(
out_opt,
flags.clone(),
local_model,
distributed_runtime.clone(),
)
.await?;
// Run it from an input
dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?;
dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?;
Ok(())
}
......@@ -112,7 +121,7 @@ async fn engine_for(
out_opt: Output,
flags: Flags,
local_model: LocalModel,
rt: Either<Runtime, DistributedRuntime>,
drt: DistributedRuntime,
) -> anyhow::Result<EngineConfig> {
match out_opt {
Output::Auto => {
......@@ -135,10 +144,6 @@ async fn engine_for(
is_static: flags.static_worker,
}),
Output::Mocker => {
let Either::Right(drt) = rt else {
panic!("Mocker requires a distributed runtime to run.");
};
let args = flags.mocker_config();
let endpoint = local_model.endpoint_id().clone();
......
......@@ -1606,6 +1606,7 @@ dependencies = [
"figment",
"futures",
"humantime",
"inotify",
"local-ip-address",
"log",
"nid",
......@@ -2857,6 +2858,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.3",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "instant"
version = "0.1.13"
......
......@@ -115,7 +115,7 @@ def parse_args():
async def run():
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False)
runtime = DistributedRuntime(loop, "etcd", False)
args = parse_args()
......
......@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::distributed::DistributedConfig;
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use futures::StreamExt;
use once_cell::sync::OnceCell;
use pyo3::IntoPyObjectExt;
......@@ -426,7 +428,9 @@ enum ModelInput {
#[pymethods]
impl DistributedRuntime {
#[new]
fn new(event_loop: PyObject, is_static: bool) -> PyResult<Self> {
fn new(event_loop: PyObject, store_kv: String, is_static: bool) -> PyResult<Self> {
let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?;
// Try to get existing runtime first, create new Worker only if needed
// This allows multiple DistributedRuntime instances to share the same tokio runtime
let runtime = rs::Worker::runtime_from_existing()
......@@ -464,9 +468,14 @@ impl DistributedRuntime {
rs::DistributedRuntime::from_settings_without_discovery(runtime),
)
} else {
let config = DistributedConfig {
store_backend: selected_kv_store,
is_static: false,
nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
};
runtime
.secondary()
.block_on(rs::DistributedRuntime::from_settings(runtime))
.block_on(rs::DistributedRuntime::new(runtime, config))
};
let inner = inner.map_err(to_pyerr)?;
......@@ -628,7 +637,7 @@ impl DistributedRuntime {
}
fn shutdown(&self) {
self.inner.runtime().shutdown();
self.inner.shutdown();
}
fn event_loop(&self) -> PyObject {
......
......@@ -299,7 +299,7 @@ pub fn run_input<'p>(
let input_enum: Input = input.parse().map_err(to_pyerr)?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
dynamo_llm::entrypoint::input::run_input(
either::Either::Right(distributed_runtime.inner.clone()),
distributed_runtime.inner.clone(),
input_enum,
engine_config.inner,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment