Unverified commit 886506c1, authored by Graham King and committed by GitHub
Browse files

feat: Command line flag to set request plane mode: tcp, http or nats (#4365)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 9f3d1b90
...@@ -238,18 +238,6 @@ dependencies = [ ...@@ -238,18 +238,6 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "async-channel"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2"
dependencies = [
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]] [[package]]
name = "async-nats" name = "async-nats"
version = "0.40.0" version = "0.40.0"
...@@ -2089,16 +2077,6 @@ dependencies = [ ...@@ -2089,16 +2077,6 @@ dependencies = [
"syn 2.0.110", "syn 2.0.110",
] ]
[[package]]
name = "dlpark"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc178fc3bf4ce54c26ccffcf271ff574954ac4b940f15121be3d69f277194537"
dependencies = [
"half 2.7.1",
"pyo3",
]
[[package]] [[package]]
name = "dlv-list" name = "dlv-list"
version = "0.5.2" version = "0.5.2"
...@@ -4072,15 +4050,6 @@ dependencies = [ ...@@ -4072,15 +4050,6 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -4170,15 +4139,6 @@ dependencies = [ ...@@ -4170,15 +4139,6 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "inventory"
version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc61209c082fbeb19919bee74b176221b27223e27b65d781eb91af24eb1fb46e"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "iovec" name = "iovec"
version = "0.1.4" version = "0.1.4"
...@@ -4601,40 +4561,6 @@ dependencies = [ ...@@ -4601,40 +4561,6 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "kvbm-py3"
version = "0.7.0"
dependencies = [
"anyhow",
"async-stream",
"async-trait",
"cudarc 0.16.6",
"derive-getters",
"dlpark",
"dynamo-llm",
"dynamo-runtime",
"either",
"futures",
"local-ip-address",
"once_cell",
"prometheus",
"pyo3",
"pyo3-async-runtimes",
"pythonize",
"rand 0.9.2",
"rstest 0.25.0",
"serde",
"serde_json",
"socket2 0.6.1",
"thiserror 2.0.17",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid 1.18.1",
]
[[package]] [[package]]
name = "lalrpop-util" name = "lalrpop-util"
version = "0.20.2" version = "0.20.2"
...@@ -5004,15 +4930,6 @@ dependencies = [ ...@@ -5004,15 +4930,6 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "metal" name = "metal"
version = "0.27.0" version = "0.27.0"
...@@ -5581,7 +5498,7 @@ dependencies = [ ...@@ -5581,7 +5498,7 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"cfg-if 1.0.4", "cfg-if 1.0.4",
"libc", "libc",
"memoffset 0.7.1", "memoffset",
"pin-utils", "pin-utils",
] ]
...@@ -6751,107 +6668,6 @@ dependencies = [ ...@@ -6751,107 +6668,6 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.4",
"indoc",
"libc",
"memoffset 0.9.1",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-async-runtimes"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "977dc837525cfd22919ba6a831413854beb7c99a256c03bf8624ad707e45810e"
dependencies = [
"async-channel",
"clap 4.5.51",
"futures",
"inventory",
"once_cell",
"pin-project-lite",
"pyo3",
"pyo3-async-runtimes-macros",
"tokio",
]
[[package]]
name = "pyo3-async-runtimes-macros"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2df2884957d2476731f987673befac5d521dff10abb0a7cbe12015bc7702fe9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.110",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.110",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.110",
]
[[package]]
name = "pythonize"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91a6ee7a084f913f98d70cdc3ebec07e852b735ae3059a1500db2661265da9ff"
dependencies = [
"pyo3",
"serde",
]
[[package]] [[package]]
name = "qoi" name = "qoi"
version = "0.4.1" version = "0.4.1"
...@@ -7442,18 +7258,6 @@ dependencies = [ ...@@ -7442,18 +7258,6 @@ dependencies = [
"rustc_version", "rustc_version",
] ]
[[package]]
name = "rstest"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d"
dependencies = [
"futures-timer",
"futures-util",
"rstest_macros 0.25.0",
"rustc_version",
]
[[package]] [[package]]
name = "rstest_macros" name = "rstest_macros"
version = "0.18.2" version = "0.18.2"
...@@ -7489,24 +7293,6 @@ dependencies = [ ...@@ -7489,24 +7293,6 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "rstest_macros"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746"
dependencies = [
"cfg-if 1.0.4",
"glob",
"proc-macro-crate",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.110",
"unicode-ident",
]
[[package]] [[package]]
name = "rstest_reuse" name = "rstest_reuse"
version = "0.7.0" version = "0.7.0"
...@@ -9979,12 +9765,6 @@ dependencies = [ ...@@ -9979,12 +9765,6 @@ dependencies = [
"rand 0.8.5", "rand 0.8.5",
] ]
[[package]]
name = "unindent"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
[[package]] [[package]]
name = "unsafe-libyaml" name = "unsafe-libyaml"
version = "0.2.11" version = "0.2.11"
......
...@@ -216,9 +216,17 @@ def parse_args(): ...@@ -216,9 +216,17 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
choices=["etcd", "file", "mem"],
default=os.environ.get("DYN_STORE_KV", "etcd"), default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
) )
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
flags = parser.parse_args() flags = parser.parse_args()
...@@ -253,7 +261,7 @@ async def async_main(): ...@@ -253,7 +261,7 @@ async def async_main():
os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, flags.store_kv) runtime = DistributedRuntime(loop, flags.store_kv, flags.request_plane)
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
......
...@@ -207,9 +207,17 @@ def parse_args(): ...@@ -207,9 +207,17 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
choices=["etcd", "file", "mem"],
default=os.environ.get("DYN_STORE_KV", "etcd"), default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
) )
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
args = parser.parse_args() args = parser.parse_args()
validate_worker_type_args(args) validate_worker_type_args(args)
......
...@@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
# Create a separate DistributedRuntime for this worker (on same event loop) # Create a separate DistributedRuntime for this worker (on same event loop)
runtime = DistributedRuntime(loop, args.store_kv) runtime = DistributedRuntime(loop, args.store_kv, args.request_plane)
runtimes.append(runtime) runtimes.append(runtime)
# Create EntrypointArgs for this worker # Create EntrypointArgs for this worker
......
...@@ -99,9 +99,17 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -99,9 +99,17 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"store-kv": { "store-kv": {
"flags": ["--store-kv"], "flags": ["--store-kv"],
"type": str, "type": str,
"choices": ["etcd", "file", "mem"],
"default": os.environ.get("DYN_STORE_KV", "etcd"), "default": os.environ.get("DYN_STORE_KV", "etcd"),
"help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", "help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
}, },
"request-plane": {
"flags": ["--request-plane"],
"type": str,
"choices": ["nats", "http", "tcp"],
"default": os.environ.get("DYN_REQUEST_PLANE", "nats"),
"help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
},
} }
...@@ -112,6 +120,7 @@ class DynamoArgs: ...@@ -112,6 +120,7 @@ class DynamoArgs:
endpoint: str endpoint: str
migration_limit: int migration_limit: int
store_kv: str store_kv: str
request_plane: str
# tool and reasoning parser options # tool and reasoning parser options
tool_call_parser: Optional[str] = None tool_call_parser: Optional[str] = None
...@@ -448,6 +457,7 @@ async def parse_args(args: list[str]) -> Config: ...@@ -448,6 +457,7 @@ async def parse_args(args: list[str]) -> Config:
endpoint=parsed_endpoint_name, endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit, migration_limit=parsed_args.migration_limit,
store_kv=parsed_args.store_kv, store_kv=parsed_args.store_kv,
request_plane=parsed_args.request_plane,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path, custom_jinja_template=expanded_template_path,
......
...@@ -69,7 +69,9 @@ async def worker(): ...@@ -69,7 +69,9 @@ async def worker():
dump_config(config.dynamo_args.dump_config_to, config) dump_config(config.dynamo_args.dump_config_to, config)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.dynamo_args.store_kv) runtime = DistributedRuntime(
loop, config.dynamo_args.store_kv, config.dynamo_args.request_plane
)
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
......
...@@ -106,7 +106,7 @@ async def worker(): ...@@ -106,7 +106,7 @@ async def worker():
config = cmd_line_args() config = cmd_line_args()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv) runtime = DistributedRuntime(loop, config.store_kv, config.request_plane)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
......
...@@ -59,6 +59,7 @@ class Config: ...@@ -59,6 +59,7 @@ class Config:
self.dump_config_to: Optional[str] = None self.dump_config_to: Optional[str] = None
self.custom_jinja_template: Optional[str] = None self.custom_jinja_template: Optional[str] = None
self.store_kv: str = "" self.store_kv: str = ""
self.request_plane: str = ""
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -90,7 +91,8 @@ class Config: ...@@ -90,7 +91,8 @@ class Config:
f"tool_call_parser={self.tool_call_parser}, " f"tool_call_parser={self.tool_call_parser}, "
f"dump_config_to={self.dump_config_to}, " f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}, " f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}" f"store_kv={self.store_kv}, "
f"request_plane={self.request_plane}"
) )
...@@ -283,9 +285,17 @@ def cmd_line_args(): ...@@ -283,9 +285,17 @@ def cmd_line_args():
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
choices=["etcd", "file", "mem"],
default=os.environ.get("DYN_STORE_KV", "etcd"), default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
) )
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -346,6 +356,7 @@ def cmd_line_args(): ...@@ -346,6 +356,7 @@ def cmd_line_args():
config.tool_call_parser = args.dyn_tool_call_parser config.tool_call_parser = args.dyn_tool_call_parser
config.dump_config_to = args.dump_config_to config.dump_config_to = args.dump_config_to
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane
# Handle custom jinja template path expansion (environment variables and home directory) # Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template: if args.custom_jinja_template:
......
...@@ -42,6 +42,7 @@ class Config: ...@@ -42,6 +42,7 @@ class Config:
migration_limit: int = 0 migration_limit: int = 0
custom_jinja_template: Optional[str] = None custom_jinja_template: Optional[str] = None
store_kv: str store_kv: str
request_plane: str
# mirror vLLM # mirror vLLM
model: str model: str
...@@ -177,9 +178,17 @@ def parse_args() -> Config: ...@@ -177,9 +178,17 @@ def parse_args() -> Config:
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
choices=["etcd", "file", "mem"],
default=os.environ.get("DYN_STORE_KV", "etcd"), default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
) )
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
add_config_dump_args(parser) add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -258,6 +267,7 @@ def parse_args() -> Config: ...@@ -258,6 +267,7 @@ def parse_args() -> Config:
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None: if config.custom_jinja_template is not None:
......
...@@ -75,7 +75,7 @@ async def worker(): ...@@ -75,7 +75,7 @@ async def worker():
config = parse_args() config = parse_args()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv) runtime = DistributedRuntime(loop, config.store_kv, config.request_plane)
overwrite_args(config) overwrite_args(config)
......
...@@ -50,7 +50,7 @@ async def main(): ...@@ -50,7 +50,7 @@ async def main():
return return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file") runtime = DistributedRuntime(loop, "file", "nats")
# Connect to middle server or direct server based on argument # Connect to middle server or direct server based on argument
if use_middle_server: if use_middle_server:
......
...@@ -50,7 +50,7 @@ class MiddleServer: ...@@ -50,7 +50,7 @@ class MiddleServer:
async def main(): async def main():
"""Start the middle server""" """Start the middle server"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file") runtime = DistributedRuntime(loop, "file", "nats")
# Create middle server handler # Create middle server handler
handler = MiddleServer(runtime) handler = MiddleServer(runtime)
......
...@@ -31,7 +31,7 @@ class DemoServer: ...@@ -31,7 +31,7 @@ class DemoServer:
async def main(): async def main():
"""Start the demo server""" """Start the demo server"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file") runtime = DistributedRuntime(loop, "file", "nats")
# Create server component # Create server component
component = runtime.namespace("demo").component("server") component = runtime.namespace("demo").component("server")
......
...@@ -121,9 +121,9 @@ async def main(): ...@@ -121,9 +121,9 @@ async def main():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if is_static: if is_static:
runtime = DistributedRuntime(loop, "file") runtime = DistributedRuntime(loop, "file", "nats")
else: else:
runtime = DistributedRuntime(loop, "etcd") runtime = DistributedRuntime(loop, "etcd", "nats")
try: try:
await worker(runtime) # type: ignore[arg-type] await worker(runtime) # type: ignore[arg-type]
......
...@@ -122,7 +122,7 @@ async def async_main(): ...@@ -122,7 +122,7 @@ async def async_main():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Create DistributedRuntime - similar to frontend/main.py line 246 # Create DistributedRuntime - similar to frontend/main.py line 246
runtime = DistributedRuntime(loop, "file") # type: ignore[call-arg] runtime = DistributedRuntime(loop, "file", "nats") # type: ignore[call-arg]
# Setup signal handlers for graceful shutdown # Setup signal handlers for graceful shutdown
def signal_handler(): def signal_handler():
......
...@@ -123,11 +123,14 @@ pub struct Flags { ...@@ -123,11 +123,14 @@ pub struct Flags {
/// Which key-value backend to use: etcd, mem, file. /// Which key-value backend to use: etcd, mem, file.
/// Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. /// Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details.
/// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv. /// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.
#[arg(long, default_value = "etcd")] #[arg(long, default_value = "etcd", value_parser = ["etcd", "file", "mem"])]
pub store_kv: String, pub store_kv: String,
/// Everything after a `--`. /// Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp].
/// These are the command line arguments to the python engine when using `pystr` or `pytok`. #[arg(long, default_value = "nats", value_parser = ["nats", "http", "tcp"])]
pub request_plane: String,
/// Everything after a `--`. Not currently used.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
pub last: Vec<String>, pub last: Vec<String>,
} }
......
...@@ -5,7 +5,7 @@ use anyhow::Context as _; ...@@ -5,7 +5,7 @@ use anyhow::Context as _;
use dynamo_llm::entrypoint::EngineConfig; use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::entrypoint::input::Input; use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::distributed::DistributedConfig; use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use dynamo_runtime::transports::nats; use dynamo_runtime::transports::nats;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
...@@ -78,9 +78,11 @@ pub async fn run( ...@@ -78,9 +78,11 @@ pub async fn run(
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
} }
let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?; let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?;
let request_plane: RequestPlaneMode = flags.request_plane.parse()?;
let dst_config = DistributedConfig { let dst_config = DistributedConfig {
store_backend: selected_store, store_backend: selected_store,
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
request_plane,
}; };
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let local_model = builder.build().await?; let local_model = builder.build().await?;
......
...@@ -115,7 +115,7 @@ def parse_args(): ...@@ -115,7 +115,7 @@ def parse_args():
async def run(): async def run():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "etcd") runtime = DistributedRuntime(loop, "etcd", "nats")
args = parse_args() args = parse_args()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::local_model::LocalModel; use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::distributed::DistributedConfig; use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
use futures::StreamExt; use futures::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
...@@ -429,8 +429,9 @@ enum ModelInput { ...@@ -429,8 +429,9 @@ enum ModelInput {
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject, store_kv: String) -> PyResult<Self> { fn new(event_loop: PyObject, store_kv: String, request_plane: String) -> PyResult<Self> {
let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?; let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?;
let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?;
// Try to get existing runtime first, create new Worker only if needed // Try to get existing runtime first, create new Worker only if needed
// This allows multiple DistributedRuntime instances to share the same tokio runtime // This allows multiple DistributedRuntime instances to share the same tokio runtime
...@@ -463,6 +464,7 @@ impl DistributedRuntime { ...@@ -463,6 +464,7 @@ impl DistributedRuntime {
let runtime_config = DistributedConfig { let runtime_config = DistributedConfig {
store_backend: selected_kv_store, store_backend: selected_kv_store,
nats_config: dynamo_runtime::transports::nats::ClientOptions::default(), nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
request_plane,
}; };
let inner = runtime let inner = runtime
.secondary() .secondary()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os
from functools import wraps from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, Union from typing import Any, AsyncGenerator, Callable, Type, Union
...@@ -25,7 +26,8 @@ def dynamo_worker(): ...@@ -25,7 +26,8 @@ def dynamo_worker():
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "etcd") request_plane = os.environ.get("DYN_REQUEST_PLANE", "nats")
runtime = DistributedRuntime(loop, "etcd", request_plane)
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment