Unverified Commit 794c0a44 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(keyvalue): Filesystem backed KeyValueStore (#4138)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 3fd0ab3d
...@@ -2391,6 +2391,7 @@ dependencies = [ ...@@ -2391,6 +2391,7 @@ dependencies = [
"figment", "figment",
"futures", "futures",
"humantime", "humantime",
"inotify",
"jsonschema", "jsonschema",
"local-ip-address", "local-ip-address",
"log", "log",
...@@ -3992,6 +3993,28 @@ version = "0.1.15" ...@@ -3992,6 +3993,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.4",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "insta" name = "insta"
version = "1.43.2" version = "1.43.2"
......
...@@ -225,6 +225,12 @@ def parse_args(): ...@@ -225,6 +225,12 @@ def parse_args():
), ),
help=f"Interval in seconds for polling custom backend metrics. Set to > 0 to enable polling (default: 0=disabled, suggested: 9.2s which is less than typical Prometheus scrape interval). Can be set via {CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR} env var.", help=f"Interval in seconds for polling custom backend metrics. Set to > 0 to enable polling (default: 0=disabled, suggested: 9.2s which is less than typical Prometheus scrape interval). Can be set via {CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR} env var.",
) )
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
flags = parser.parse_args() flags = parser.parse_args()
...@@ -252,8 +258,7 @@ async def async_main(): ...@@ -252,8 +258,7 @@ async def async_main():
os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, flags.store_kv, is_static)
runtime = DistributedRuntime(loop, is_static)
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
......
...@@ -204,6 +204,12 @@ def parse_args(): ...@@ -204,6 +204,12 @@ def parse_args():
default=False, default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)",
) )
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
args = parser.parse_args() args = parser.parse_args()
validate_worker_type_args(args) validate_worker_type_args(args)
......
...@@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
# Create a separate DistributedRuntime for this worker (on same event loop) # Create a separate DistributedRuntime for this worker (on same event loop)
runtime = DistributedRuntime(loop, False) runtime = DistributedRuntime(loop, args.store_kv, False)
runtimes.append(runtime) runtimes.append(runtime)
# Create EntrypointArgs for this worker # Create EntrypointArgs for this worker
......
...@@ -93,6 +93,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -93,6 +93,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": None, "default": None,
"help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.", "help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.",
}, },
"store-kv": {
"flags": ["--store-kv"],
"type": str,
"default": os.environ.get("DYN_STORE_KV", "etcd"),
"help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
},
} }
...@@ -102,6 +108,7 @@ class DynamoArgs: ...@@ -102,6 +108,7 @@ class DynamoArgs:
component: str component: str
endpoint: str endpoint: str
migration_limit: int migration_limit: int
store_kv: str
# tool and reasoning parser options # tool and reasoning parser options
tool_call_parser: Optional[str] = None tool_call_parser: Optional[str] = None
...@@ -329,6 +336,7 @@ async def parse_args(args: list[str]) -> Config: ...@@ -329,6 +336,7 @@ async def parse_args(args: list[str]) -> Config:
component=parsed_component_name, component=parsed_component_name,
endpoint=parsed_endpoint_name, endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit, migration_limit=parsed_args.migration_limit,
store_kv=parsed_args.store_kv,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path, custom_jinja_template=expanded_template_path,
......
...@@ -11,7 +11,7 @@ import uvloop ...@@ -11,7 +11,7 @@ import uvloop
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import ( from dynamo.sglang.health_check import (
...@@ -33,9 +33,12 @@ from dynamo.sglang.request_handlers import ( ...@@ -33,9 +33,12 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging() configure_dynamo_logging()
@dynamo_worker(static=False) async def worker():
async def worker(runtime: DistributedRuntime): config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.dynamo_args.store_kv, False)
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
...@@ -45,9 +48,6 @@ async def worker(runtime: DistributedRuntime): ...@@ -45,9 +48,6 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime") logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
if config.dynamo_args.embedding_worker: if config.dynamo_args.embedding_worker:
await init_embedding(runtime, config) await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor: elif config.dynamo_args.multimodal_processor:
......
...@@ -39,7 +39,7 @@ import dynamo.nixl_connect as nixl_connect ...@@ -39,7 +39,7 @@ import dynamo.nixl_connect as nixl_connect
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.utils.prometheus import register_engine_metrics_callback from dynamo.common.utils.prometheus import register_engine_metrics_callback
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
...@@ -102,11 +102,13 @@ async def get_engine_runtime_config( ...@@ -102,11 +102,13 @@ async def get_engine_runtime_config(
return runtime_config return runtime_config
@dynamo_worker(static=False) async def worker():
async def worker(runtime: DistributedRuntime): config = cmd_line_args()
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, False)
# Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly # Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
...@@ -116,7 +118,6 @@ async def worker(runtime: DistributedRuntime): ...@@ -116,7 +118,6 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers set up for graceful shutdown") logging.info("Signal handlers set up for graceful shutdown")
config = cmd_line_args()
await init(runtime, config) await init(runtime, config)
......
...@@ -58,6 +58,7 @@ class Config: ...@@ -58,6 +58,7 @@ class Config:
self.tool_call_parser: Optional[str] = None self.tool_call_parser: Optional[str] = None
self.dump_config_to: Optional[str] = None self.dump_config_to: Optional[str] = None
self.custom_jinja_template: Optional[str] = None self.custom_jinja_template: Optional[str] = None
self.store_kv: str = ""
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -87,8 +88,9 @@ class Config: ...@@ -87,8 +88,9 @@ class Config:
f"max_file_size_mb={self.max_file_size_mb}, " f"max_file_size_mb={self.max_file_size_mb}, "
f"reasoning_parser={self.reasoning_parser}, " f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}, " f"tool_call_parser={self.tool_call_parser}, "
f"dump_config_to={self.dump_config_to}," f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}" f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}"
) )
...@@ -278,6 +280,12 @@ def cmd_line_args(): ...@@ -278,6 +280,12 @@ def cmd_line_args():
default=None, default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.", help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
) )
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -337,6 +345,7 @@ def cmd_line_args(): ...@@ -337,6 +345,7 @@ def cmd_line_args():
config.reasoning_parser = args.dyn_reasoning_parser config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser config.tool_call_parser = args.dyn_tool_call_parser
config.dump_config_to = args.dump_config_to config.dump_config_to = args.dump_config_to
config.store_kv = args.store_kv
# Handle custom jinja template path expansion (environment variables and home directory) # Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template: if args.custom_jinja_template:
......
...@@ -38,6 +38,7 @@ class Config: ...@@ -38,6 +38,7 @@ class Config:
migration_limit: int = 0 migration_limit: int = 0
kv_port: Optional[int] = None kv_port: Optional[int] = None
custom_jinja_template: Optional[str] = None custom_jinja_template: Optional[str] = None
store_kv: str
# mirror vLLM # mirror vLLM
model: str model: str
...@@ -164,6 +165,12 @@ def parse_args() -> Config: ...@@ -164,6 +165,12 @@ def parse_args() -> Config:
"'USER: <image> please describe the image ASSISTANT:'." "'USER: <image> please describe the image ASSISTANT:'."
), ),
) )
parser.add_argument(
"--store-kv",
type=str,
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
add_config_dump_args(parser) add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -233,6 +240,7 @@ def parse_args() -> Config: ...@@ -233,6 +240,7 @@ def parse_args() -> Config:
config.multimodal_worker = args.multimodal_worker config.multimodal_worker = args.multimodal_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None: if config.custom_jinja_template is not None:
......
...@@ -25,7 +25,7 @@ from dynamo.llm import ( ...@@ -25,7 +25,7 @@ from dynamo.llm import (
fetch_llm, fetch_llm,
register_llm, register_llm,
) )
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import ( from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler, EncodeWorkerHandler,
...@@ -70,16 +70,16 @@ async def graceful_shutdown(runtime): ...@@ -70,16 +70,16 @@ async def graceful_shutdown(runtime):
logging.info("DistributedRuntime shutdown complete") logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False) async def worker():
async def worker(runtime: DistributedRuntime):
config = parse_args() config = parse_args()
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, False)
await configure_ports(config) await configure_ports(config)
overwrite_args(config) overwrite_args(config)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
......
...@@ -50,7 +50,7 @@ async def main(): ...@@ -50,7 +50,7 @@ async def main():
return return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True) runtime = DistributedRuntime(loop, "mem", True)
# Connect to middle server or direct server based on argument # Connect to middle server or direct server based on argument
if use_middle_server: if use_middle_server:
......
...@@ -50,7 +50,7 @@ class MiddleServer: ...@@ -50,7 +50,7 @@ class MiddleServer:
async def main(): async def main():
"""Start the middle server""" """Start the middle server"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True) runtime = DistributedRuntime(loop, "mem", True)
# Create middle server handler # Create middle server handler
handler = MiddleServer(runtime) handler = MiddleServer(runtime)
......
...@@ -31,7 +31,7 @@ class DemoServer: ...@@ -31,7 +31,7 @@ class DemoServer:
async def main(): async def main():
"""Start the demo server""" """Start the demo server"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True) runtime = DistributedRuntime(loop, "mem", True)
# Create server component # Create server component
component = runtime.namespace("demo").component("server") component = runtime.namespace("demo").component("server")
......
...@@ -123,7 +123,7 @@ async def async_main(): ...@@ -123,7 +123,7 @@ async def async_main():
# Create DistributedRuntime - similar to frontend/main.py line 246 # Create DistributedRuntime - similar to frontend/main.py line 246
is_static = True # Use static mode (no etcd) is_static = True # Use static mode (no etcd)
runtime = DistributedRuntime(loop, is_static) # type: ignore[call-arg] runtime = DistributedRuntime(loop, "mem", is_static) # type: ignore[call-arg]
# Setup signal handlers for graceful shutdown # Setup signal handlers for graceful shutdown
def signal_handler(): def signal_handler():
......
...@@ -127,6 +127,12 @@ pub struct Flags { ...@@ -127,6 +127,12 @@ pub struct Flags {
#[arg(long, default_value = "false")] #[arg(long, default_value = "false")]
pub static_worker: bool, pub static_worker: bool,
/// Which key-value backend to use: etcd, mem, file.
/// Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details.
/// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.
#[arg(long, default_value = "etcd")]
pub store_kv: String,
/// Everything after a `--`. /// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`. /// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......
...@@ -6,10 +6,11 @@ use dynamo_llm::entrypoint::EngineConfig; ...@@ -6,10 +6,11 @@ use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::entrypoint::input::Input; use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::distributed::DistributedConfig; use dynamo_runtime::distributed::DistributedConfig;
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use dynamo_runtime::transports::nats;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
mod flags; mod flags;
use either::Either;
pub use flags::Flags; pub use flags::Flags;
mod opt; mod opt;
pub use dynamo_llm::request_template::RequestTemplate; pub use dynamo_llm::request_template::RequestTemplate;
...@@ -73,14 +74,16 @@ pub async fn run( ...@@ -73,14 +74,16 @@ pub async fn run(
// TODO: old, address this later: // TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one. // If not, then the endpoint isn't exposed so we let LocalModel invent one.
let mut rt = Either::Left(runtime.clone());
if let Input::Endpoint(path) = &in_opt { if let Input::Endpoint(path) = &in_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
}
let dst_config = DistributedConfig::from_settings(flags.static_worker); let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?;
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let dst_config = DistributedConfig {
rt = Either::Right(distributed_runtime); store_backend: selected_store,
nats_config: nats::ClientOptions::default(),
is_static: flags.static_worker,
}; };
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
if let Some(Output::Static(path)) = &out_opt { if let Some(Output::Static(path)) = &out_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
} }
...@@ -98,10 +101,16 @@ pub async fn run( ...@@ -98,10 +101,16 @@ pub async fn run(
flags.validate(&in_opt, &out_opt)?; flags.validate(&in_opt, &out_opt)?;
// Make an engine from the local_model, flags and output. // Make an engine from the local_model, flags and output.
let engine_config = engine_for(out_opt, flags.clone(), local_model, rt.clone()).await?; let engine_config = engine_for(
out_opt,
flags.clone(),
local_model,
distributed_runtime.clone(),
)
.await?;
// Run it from an input // Run it from an input
dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?; dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?;
Ok(()) Ok(())
} }
...@@ -112,7 +121,7 @@ async fn engine_for( ...@@ -112,7 +121,7 @@ async fn engine_for(
out_opt: Output, out_opt: Output,
flags: Flags, flags: Flags,
local_model: LocalModel, local_model: LocalModel,
rt: Either<Runtime, DistributedRuntime>, drt: DistributedRuntime,
) -> anyhow::Result<EngineConfig> { ) -> anyhow::Result<EngineConfig> {
match out_opt { match out_opt {
Output::Auto => { Output::Auto => {
...@@ -135,10 +144,6 @@ async fn engine_for( ...@@ -135,10 +144,6 @@ async fn engine_for(
is_static: flags.static_worker, is_static: flags.static_worker,
}), }),
Output::Mocker => { Output::Mocker => {
let Either::Right(drt) = rt else {
panic!("Mocker requires a distributed runtime to run.");
};
let args = flags.mocker_config(); let args = flags.mocker_config();
let endpoint = local_model.endpoint_id().clone(); let endpoint = local_model.endpoint_id().clone();
......
...@@ -1606,6 +1606,7 @@ dependencies = [ ...@@ -1606,6 +1606,7 @@ dependencies = [
"figment", "figment",
"futures", "futures",
"humantime", "humantime",
"inotify",
"local-ip-address", "local-ip-address",
"log", "log",
"nid", "nid",
...@@ -2857,6 +2858,28 @@ version = "0.1.15" ...@@ -2857,6 +2858,28 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "inotify"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.3",
"futures-core",
"inotify-sys",
"libc",
"tokio",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "instant" name = "instant"
version = "0.1.13" version = "0.1.13"
......
...@@ -115,7 +115,7 @@ def parse_args(): ...@@ -115,7 +115,7 @@ def parse_args():
async def run(): async def run():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False) runtime = DistributedRuntime(loop, "etcd", False)
args = parse_args() args = parse_args()
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::local_model::LocalModel; use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::distributed::DistributedConfig;
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use futures::StreamExt; use futures::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use pyo3::IntoPyObjectExt; use pyo3::IntoPyObjectExt;
...@@ -426,7 +428,9 @@ enum ModelInput { ...@@ -426,7 +428,9 @@ enum ModelInput {
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject, is_static: bool) -> PyResult<Self> { fn new(event_loop: PyObject, store_kv: String, is_static: bool) -> PyResult<Self> {
let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?;
// Try to get existing runtime first, create new Worker only if needed // Try to get existing runtime first, create new Worker only if needed
// This allows multiple DistributedRuntime instances to share the same tokio runtime // This allows multiple DistributedRuntime instances to share the same tokio runtime
let runtime = rs::Worker::runtime_from_existing() let runtime = rs::Worker::runtime_from_existing()
...@@ -464,9 +468,14 @@ impl DistributedRuntime { ...@@ -464,9 +468,14 @@ impl DistributedRuntime {
rs::DistributedRuntime::from_settings_without_discovery(runtime), rs::DistributedRuntime::from_settings_without_discovery(runtime),
) )
} else { } else {
let config = DistributedConfig {
store_backend: selected_kv_store,
is_static: false,
nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
};
runtime runtime
.secondary() .secondary()
.block_on(rs::DistributedRuntime::from_settings(runtime)) .block_on(rs::DistributedRuntime::new(runtime, config))
}; };
let inner = inner.map_err(to_pyerr)?; let inner = inner.map_err(to_pyerr)?;
...@@ -628,7 +637,7 @@ impl DistributedRuntime { ...@@ -628,7 +637,7 @@ impl DistributedRuntime {
} }
fn shutdown(&self) { fn shutdown(&self) {
self.inner.runtime().shutdown(); self.inner.shutdown();
} }
fn event_loop(&self) -> PyObject { fn event_loop(&self) -> PyObject {
......
...@@ -299,7 +299,7 @@ pub fn run_input<'p>( ...@@ -299,7 +299,7 @@ pub fn run_input<'p>(
let input_enum: Input = input.parse().map_err(to_pyerr)?; let input_enum: Input = input.parse().map_err(to_pyerr)?;
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
dynamo_llm::entrypoint::input::run_input( dynamo_llm::entrypoint::input::run_input(
either::Either::Right(distributed_runtime.inner.clone()), distributed_runtime.inner.clone(),
input_enum, input_enum,
engine_config.inner, engine_config.inner,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment