Unverified Commit c03e2f6b authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Migrate planner virtual_connector internals into bindings (#3205)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent e1c7a4a4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import logging import logging
import os import os
import time
from typing import Optional from typing import Optional
from dynamo._core import VirtualConnectorCoordinator
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
from dynamo.runtime import DistributedRuntime, EtcdKvCache from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -28,105 +27,37 @@ SCALING_MAX_RETRIES = SCALING_MAX_WAIT_TIME // SCALING_CHECK_INTERVAL # 180 ret ...@@ -28,105 +27,37 @@ SCALING_MAX_RETRIES = SCALING_MAX_WAIT_TIME // SCALING_CHECK_INTERVAL # 180 ret
class VirtualConnector(PlannerConnector): class VirtualConnector(PlannerConnector):
""" """
This is a virtual connector for planner to output scaling decisions to non-native environments This is a virtual connector for planner to output scaling decisions to non-native environments
This virtual connector does not actually scale the deployment, instead, it communicates with the non-native environment through ETCD This virtual connector does not actually scale the deployment, instead, it communicates with the non-native environment through dynamo-runtime's VirtualConnectorCoordinator.
The deployment environment needs to read from ETCD to receive the scaling decisions and update ETCD to report scaling status The deployment environment needs to use VirtualConnectorClient (in the Rust/Python bindings) to read the scaling decisions and report scaling status.
The prefix for the ETCD key is /{dynamo_namespace}/planner/
To output the scaling decisions, planner write to three keys:
- num_prefill_workers: an integer (stored as string), specifying how many prefill workers the deployment should have in the last scaling decision
- num_decode_workers: an integer (stored as string), specifying how many decode workers the deployment should have in the last scaling decision
- decision_id: an integer (stored as string), specifying an incremental id for the last scaling decision, if there's no scaling decision, the value would be -1
To receive the status of the scaling decisions, the deployment environment needs to write to this key:
- scaled_decision_id: an integer (stored as string), specifying if the newest decision_id that has been scaled
""" """
def __init__( def __init__(
self, runtime: DistributedRuntime, dynamo_namespace: str, backend: str self, runtime: DistributedRuntime, dynamo_namespace: str, backend: str
): ):
etcd_client = runtime.do_not_use_etcd_client() self.connector = VirtualConnectorCoordinator(
if etcd_client is None: runtime,
raise RuntimeError("ETCD client is not initialized") dynamo_namespace,
SCALING_CHECK_INTERVAL,
SCALING_MAX_WAIT_TIME,
SCALING_MAX_RETRIES,
)
self.backend = backend self.backend = backend
self.dynamo_namespace = dynamo_namespace
self.worker_component_names = WORKER_COMPONENT_NAMES[backend] self.worker_component_names = WORKER_COMPONENT_NAMES[backend]
# Initialize current worker counts
self.num_prefill_workers = 0
self.num_decode_workers = 0
self.decision_id = -1
# Track when we first started skipping scaling due to unready state
self.first_skip_timestamp: Optional[float] = None
# Store etcd_client for async initialization
self._etcd_client = etcd_client
self._etcd_kv_cache = None
async def _async_init(self): async def _async_init(self):
"""Async initialization that must be called after __init__""" """Async initialization that must be called after __init__"""
if self._etcd_kv_cache is not None: await self.connector.async_init()
return # Already initialized
# Create EtcdKvCache with initial values
initial_values = {
"num_prefill_workers": str(self.num_prefill_workers).encode("utf-8"),
"num_decode_workers": str(self.num_decode_workers).encode("utf-8"),
"decision_id": str(self.decision_id).encode("utf-8"),
}
self._etcd_kv_cache = await EtcdKvCache.create(
self._etcd_client,
f"/{self.dynamo_namespace}/planner/",
initial_values,
)
# Load current values from ETCD if they exist async def _update_scaling_decision(
await self._load_current_state() self, num_prefill: Optional[int] = None, num_decode: Optional[int] = None
):
@property """Update scaling decision"""
def etcd_kv_cache(self): await self.connector.update_scaling_decision(num_prefill, num_decode)
"""Get the etcd_kv_cache, ensuring async initialization is complete"""
if self._etcd_kv_cache is None:
raise RuntimeError(
"VirtualConnector not properly initialized. Call _async_init() first."
)
return self._etcd_kv_cache
async def _load_current_state(self):
"""Load current state from ETCD"""
# Get all current values
all_values = await self.etcd_kv_cache.get_all()
# Parse num_prefill_workers
if "num_prefill_workers" in all_values:
try:
self.num_prefill_workers = int(
all_values["num_prefill_workers"].decode("utf-8")
)
except (ValueError, AttributeError):
logger.warning(
"Failed to parse num_prefill_workers from ETCD, using default 0"
)
# Parse num_decode_workers
if "num_decode_workers" in all_values:
try:
self.num_decode_workers = int(
all_values["num_decode_workers"].decode("utf-8")
)
except (ValueError, AttributeError):
logger.warning(
"Failed to parse num_decode_workers from ETCD, using default 0"
)
# Parse decision_id async def _wait_for_scaling_completion(self):
if "decision_id" in all_values: """Wait for the deployment environment to report that scaling is complete"""
try: await self.connector.wait_for_scaling_completion()
self.decision_id = int(all_values["decision_id"].decode("utf-8"))
except (ValueError, AttributeError):
logger.warning(
"Failed to parse decision_id from ETCD, using default -1"
)
def _component_to_worker_type(self, component_name: str) -> Optional[str]: def _component_to_worker_type(self, component_name: str) -> Optional[str]:
"""Map component name to worker type (prefill or decode)""" """Map component name to worker type (prefill or decode)"""
...@@ -137,121 +68,6 @@ class VirtualConnector(PlannerConnector): ...@@ -137,121 +68,6 @@ class VirtualConnector(PlannerConnector):
else: else:
return None return None
async def _is_scaling_ready(self) -> bool:
"""Check if the previous scaling decision has been completed"""
# If this is the first decision, it's always ready
if self.decision_id == -1:
return True
# Check if scaled_decision_id matches current decision_id
scaled_decision_id_bytes = await self.etcd_kv_cache.get("scaled_decision_id")
if scaled_decision_id_bytes:
try:
scaled_decision_id = int(scaled_decision_id_bytes.decode("utf-8"))
return scaled_decision_id >= self.decision_id
except (ValueError, AttributeError):
logger.warning("Failed to parse scaled_decision_id from ETCD")
# If no scaled_decision_id exists, assume not ready
return False
async def _update_scaling_decision(
self, num_prefill: Optional[int] = None, num_decode: Optional[int] = None
):
"""Update scaling decision in ETCD"""
# Check if there's actually a change
prefill_changed = (
num_prefill is not None and num_prefill != self.num_prefill_workers
)
decode_changed = (
num_decode is not None and num_decode != self.num_decode_workers
)
if not prefill_changed and not decode_changed:
logger.info(
f"No scaling needed (prefill={self.num_prefill_workers}, decode={self.num_decode_workers}), skipping ETCD update"
)
return
# Check if previous scaling is ready
is_ready = await self._is_scaling_ready()
if not is_ready:
current_time = time.time()
# If this is the first time we're skipping, record the timestamp
if self.first_skip_timestamp is None:
self.first_skip_timestamp = current_time
logger.info(
f"Previous scaling decision #{self.decision_id} not ready, starting to track skip time"
)
# Check if we've been waiting too long
if self.first_skip_timestamp is not None:
time_waited = current_time - self.first_skip_timestamp
else:
# This should not happen since we just set it above, but for type safety
time_waited = 0.0
if time_waited < SCALING_MAX_WAIT_TIME:
logger.warning(
f"Previous scaling decision #{self.decision_id} not ready, "
f"skipping new decision (waited {time_waited:.1f}s / {SCALING_MAX_WAIT_TIME}s)"
)
return
else:
logger.warning(
f"Previous scaling decision #{self.decision_id} not ready after {SCALING_MAX_WAIT_TIME}s, "
f"proceeding with new decision anyway"
)
# Reset the skip timestamp since we're making a decision
self.first_skip_timestamp = None
# Update internal state
if num_prefill is not None:
self.num_prefill_workers = num_prefill
if num_decode is not None:
self.num_decode_workers = num_decode
# Increment decision_id
self.decision_id += 1
# Write to ETCD
await self.etcd_kv_cache.put(
"num_prefill_workers", str(self.num_prefill_workers).encode("utf-8")
)
await self.etcd_kv_cache.put(
"num_decode_workers", str(self.num_decode_workers).encode("utf-8")
)
await self.etcd_kv_cache.put(
"decision_id", str(self.decision_id).encode("utf-8")
)
logger.info(
f"Updated scaling decision #{self.decision_id}: prefill={self.num_prefill_workers}, decode={self.num_decode_workers}"
)
async def _wait_for_scaling_completion(self):
"""Wait for the deployment environment to report that scaling is complete"""
for _ in range(SCALING_MAX_RETRIES):
scaled_decision_id_bytes = await self.etcd_kv_cache.get(
"scaled_decision_id"
)
if scaled_decision_id_bytes:
try:
scaled_decision_id = int(scaled_decision_id_bytes.decode("utf-8"))
if scaled_decision_id >= self.decision_id:
logger.info(f"Scaling decision #{self.decision_id} completed")
return
except (ValueError, AttributeError):
logger.warning("Failed to parse scaled_decision_id from ETCD")
await asyncio.sleep(SCALING_CHECK_INTERVAL)
logger.warning(
f"Timeout waiting for scaling decision #{self.decision_id} to complete after {SCALING_MAX_WAIT_TIME}s"
)
async def add_component(self, component_name: str, blocking: bool = True): async def add_component(self, component_name: str, blocking: bool = True):
"""Add a component by increasing its replica count by 1""" """Add a component by increasing its replica count by 1"""
worker_type = self._component_to_worker_type(component_name) worker_type = self._component_to_worker_type(component_name)
...@@ -259,12 +75,14 @@ class VirtualConnector(PlannerConnector): ...@@ -259,12 +75,14 @@ class VirtualConnector(PlannerConnector):
logger.warning(f"Unknown component name: {component_name}, skipping") logger.warning(f"Unknown component name: {component_name}, skipping")
return return
state = self.connector.read_state()
if worker_type == "prefill": if worker_type == "prefill":
await self._update_scaling_decision( await self._update_scaling_decision(
num_prefill=self.num_prefill_workers + 1 num_prefill=state.num_prefill_workers + 1
) )
elif worker_type == "decode": elif worker_type == "decode":
await self._update_scaling_decision(num_decode=self.num_decode_workers + 1) await self._update_scaling_decision(num_decode=state.num_decode_workers + 1)
if blocking: if blocking:
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
...@@ -276,11 +94,13 @@ class VirtualConnector(PlannerConnector): ...@@ -276,11 +94,13 @@ class VirtualConnector(PlannerConnector):
logger.warning(f"Unknown component name: {component_name}, skipping") logger.warning(f"Unknown component name: {component_name}, skipping")
return return
state = self.connector.read_state()
if worker_type == "prefill": if worker_type == "prefill":
new_count = max(0, self.num_prefill_workers - 1) new_count = max(0, state.num_prefill_workers - 1)
await self._update_scaling_decision(num_prefill=new_count) await self._update_scaling_decision(num_prefill=new_count)
elif worker_type == "decode": elif worker_type == "decode":
new_count = max(0, self.num_decode_workers - 1) new_count = max(0, state.num_decode_workers - 1)
await self._update_scaling_decision(num_decode=new_count) await self._update_scaling_decision(num_decode=new_count)
if blocking: if blocking:
...@@ -307,6 +127,9 @@ class VirtualConnector(PlannerConnector): ...@@ -307,6 +127,9 @@ class VirtualConnector(PlannerConnector):
elif worker_type == "decode": elif worker_type == "decode":
num_decode = replicas num_decode = replicas
if num_prefill is None and num_decode is None:
return
# Update scaling decision if there are any changes # Update scaling decision if there are any changes
await self._update_scaling_decision( await self._update_scaling_decision(
num_prefill=num_prefill, num_decode=num_decode num_prefill=num_prefill, num_decode=num_decode
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# This test requires etcd and nats to be running, and the bindings to be installed
#
import asyncio
import logging
import pytest
from dynamo._core import DistributedRuntime, VirtualConnectorClient
from dynamo.planner import VirtualConnector
pytestmark = pytest.mark.pre_merge
logger = logging.getLogger(__name__)
NAMESPACE = "test_virtual_connector"
def get_runtime():
    """Return a DistributedRuntime, preferring one that already exists.

    First tries ``DistributedRuntime.detached()`` so that a runtime already
    initialized in this process (common in CI, where tests share a process)
    is reused. If that raises, a new runtime is constructed on the currently
    running asyncio event loop.
    """
    try:
        # Reuse the existing runtime when a worker has already been initialized.
        return DistributedRuntime.detached()
    except Exception:
        # No existing runtime: build a fresh one bound to the running loop.
        # asyncio.get_running_loop() requires an event loop to be running.
        return DistributedRuntime(asyncio.get_running_loop(), False)
# Fails in CI after 30+ minutes with:
# pyo3_runtime.PanicException: Cannot drop a runtime in a context where blocking is not allowed. This happens when a runtime is dropped from within an asynchronous context.
# Disabling until we have a faster CI to iterate with.
@pytest.mark.skip("See comment in source")
def test_main():
    """
    Connect a VirtualConnector (Dynamo Planner) and a VirtualConnectorClient (customer), and scale.
    """

    async def _run():
        # get_runtime() may fall back to asyncio.get_running_loop(), which raises
        # RuntimeError unless an event loop is already running. Resolve the runtime
        # inside the coroutine instead of while building asyncio.run()'s argument.
        await async_internal(get_runtime())

    asyncio.run(_run())
async def next_scaling_decision(c):
    """Issue the second scaling decision (5 prefill, 8 decode) without blocking.

    Kept as its own coroutine so the caller can run it as a separate task
    while it waits on the client side for the decision to arrive.
    """
    await c.set_component_replicas({"prefill": 5, "decode": 8}, blocking=False)
async def async_internal(distributed_runtime):
    """Drive three scaling decisions through a planner/client pair.

    A VirtualConnector (the planner side) publishes decisions and a
    VirtualConnectorClient (the deployment side) reads and acknowledges them,
    both sharing ``distributed_runtime`` and the test NAMESPACE.
    """
    # Planner side.
    c = VirtualConnector(distributed_runtime, NAMESPACE, "sglang")
    await c._async_init()

    # --- Decision 0: 1 prefill / 2 decode ---
    await c.set_component_replicas({"prefill": 1, "decode": 2}, blocking=False)

    # Deployment-environment side.
    client = VirtualConnectorClient(distributed_runtime, NAMESPACE)
    decision = await client.get()
    # A real client would perform the scaling here before acknowledging.
    assert decision.num_prefill_workers == 1
    assert decision.num_decode_workers == 2
    assert decision.decision_id == 0
    await client.complete(decision)
    await c._wait_for_scaling_completion()

    # --- Decision 1: issued from a separate task so client.wait() can be exercised ---
    pending = asyncio.create_task(next_scaling_decision(c))
    await client.wait()
    await pending
    decision = await client.get()
    assert decision.num_prefill_workers == 5
    assert decision.num_decode_workers == 8
    assert decision.decision_id == 1
    await client.complete(decision)
    await c._wait_for_scaling_completion()

    # --- Final decision: scale both worker types down to zero ---
    await c.set_component_replicas({"prefill": 0, "decode": 0}, blocking=False)
    decision = await client.get()
    assert decision.num_prefill_workers == 0
    assert decision.num_decode_workers == 0
    await client.complete(decision)
...@@ -123,26 +123,14 @@ kubectl apply -f disagg_planner.yaml -n {$NAMESPACE} ...@@ -123,26 +123,14 @@ kubectl apply -f disagg_planner.yaml -n {$NAMESPACE}
### Virtual Deployment ### Virtual Deployment
The SLA planner supports virtual deployment mode for customized environments (e.g., customized cluster) through the `VirtualConnector`. This connector enables the planner to communicate scaling decisions via ETCD without directly managing the deployment infrastructure. The SLA planner supports virtual deployment mode for customized environments (e.g., customized cluster) through the `VirtualConnector`. This connector enables the planner to communicate scaling decisions without directly managing the deployment infrastructure.
The `VirtualConnector` acts as a bridge between the SLA planner and external deployment environments. Instead of directly scaling Kubernetes resources, it writes scaling decisions to ETCD and waits for the deployment environment to acknowledge completion. The `VirtualConnector` acts as a bridge between the SLA planner and external deployment environments. Instead of directly scaling Kubernetes resources, it writes scaling decisions and waits for the deployment environment to acknowledge completion.
#### ETCD Communication Protocol
The VirtualConnector uses the following ETCD key structure under `/{dynamo_namespace}/planner/`:
**Planner Output Keys** (written by the planner):
- `num_prefill_workers`: Integer (stored as string) specifying the target number of prefill workers
- `num_decode_workers`: Integer (stored as string) specifying the target number of decode workers
- `decision_id`: Integer (stored as string) with incremental ID for each scaling decision (-1 if no decisions made)
**Deployment Environment Input Key** (written by the deployment environment):
- `scaled_decision_id`: Integer (stored as string) specifying the newest decision_id that has been successfully scaled
#### Scaling Decision Flow #### Scaling Decision Flow
1. **Decision Generation**: The planner calculates optimal worker counts and writes them to ETCD with an incremented `decision_id` 1. **Decision Generation**: The planner calculates optimal worker counts
2. **Change Detection**: The planner skips scaling if the target counts match current counts, logging: `"No scaling needed (prefill=X, decode=Y), skipping ETCD update"` 2. **Change Detection**: The planner skips scaling if the target counts match current counts, logging: `"No scaling needed (prefill=X, decode=Y)"`
3. **Readiness Check**: Before making new decisions, the planner verifies that previous scaling operations have completed by checking if `scaled_decision_id >= decision_id` 3. **Readiness Check**: Before making new decisions, the planner verifies that previous scaling operations have completed by checking if `scaled_decision_id >= decision_id`
4. **Timeout Handling**: If a scaling decision isn't acknowledged within 30 minutes (1800 seconds), the planner proceeds with new decisions anyway 4. **Timeout Handling**: If a scaling decision isn't acknowledged within 30 minutes (1800 seconds), the planner proceeds with new decisions anyway
5. **Completion Tracking**: The planner can optionally wait for scaling completion confirmation (blocking mode) 5. **Completion Tracking**: The planner can optionally wait for scaling completion confirmation (blocking mode)
...@@ -158,31 +146,23 @@ backend: "vllm" # or "sglang" ...@@ -158,31 +146,23 @@ backend: "vllm" # or "sglang"
#### Deployment Environment Requirements #### Deployment Environment Requirements
The external deployment environment must: The external deployment environment must use `VirtualConnectorClient`:
1. **Monitor ETCD**: Continuously watch the `/{dynamo_namespace}/planner/` prefix for scaling decisions
2. **Parse Decisions**: Read `num_prefill_workers`, `num_decode_workers`, and `decision_id` values
3. **Execute Scaling**: Apply the scaling decisions to the actual deployment infrastructure
4. **Acknowledge Completion**: Write the completed `decision_id` to `scaled_decision_id` when scaling is finished
#### Example Integration ```
from dynamo._core import DistributedRuntime, VirtualConnectorClient
```python client = VirtualConnectorClient(distributed_runtime, namespace)
# Deployment environment pseudo-code ```
async def monitor_scaling_decisions():
while True:
# Watch for changes in planner decisions
decision_id = await etcd.get("/my-namespace/planner/decision_id")
num_prefill = await etcd.get("/my-namespace/planner/num_prefill_workers")
num_decode = await etcd.get("/my-namespace/planner/num_decode_workers")
# Apply scaling to your infrastructure 1. **Monitor Planner**: Continuously watch for scaling decisions: `await client.wait()`. This blocks until there is a change.
await scale_prefill_workers(int(num_prefill)) 2. **Parse Decisions**: Read `num_prefill_workers` and `num_decode_workers` values: `decision = await client.get()`
await scale_decode_workers(int(num_decode)) 3. **Execute Scaling**: Apply the scaling decisions to the actual deployment infrastructure
4. **Acknowledge Completion**: Mark the decision completed when scaling is finished: `await client.complete(decision)`
# Acknowledge completion A scaling decision (returned by `client.get()`) contains the following fields, which are -1 if not set yet:
await etcd.put("/my-namespace/planner/scaled_decision_id", decision_id) - `num_prefill_workers`: Integer specifying the target number of prefill workers
- `num_decode_workers`: Integer specifying the target number of decode workers
- `decision_id`: Integer with incremental ID for each scaling decision
await asyncio.sleep(10) See `components/planner/test/test_virtual_connector.py` for a full example.
```
...@@ -125,6 +125,9 @@ name = "anyhow" ...@@ -125,6 +125,9 @@ name = "anyhow"
version = "1.0.99" version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
dependencies = [
"backtrace",
]
[[package]] [[package]]
name = "arbitrary" name = "arbitrary"
...@@ -152,6 +155,12 @@ dependencies = [ ...@@ -152,6 +155,12 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "arraydeque"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
version = "0.3.9" version = "0.3.9"
...@@ -479,6 +488,12 @@ version = "0.13.1" ...@@ -479,6 +488,12 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
...@@ -786,8 +801,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" ...@@ -786,8 +801,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [ dependencies = [
"android-tzdata", "android-tzdata",
"iana-time-zone", "iana-time-zone",
"js-sys",
"num-traits", "num-traits",
"serde", "serde",
"wasm-bindgen",
"windows-link", "windows-link",
] ]
...@@ -863,6 +880,15 @@ version = "1.0.4" ...@@ -863,6 +880,15 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "colored"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "compact_str" name = "compact_str"
version = "0.9.0" version = "0.9.0"
...@@ -887,6 +913,26 @@ dependencies = [ ...@@ -887,6 +913,26 @@ dependencies = [
"crossbeam-utils", "crossbeam-utils",
] ]
[[package]]
name = "config"
version = "0.15.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0faa974509d38b33ff89282db9c3295707ccf031727c0de9772038ec526852ba"
dependencies = [
"async-trait",
"convert_case",
"json5",
"pathdiff",
"ron",
"rust-ini",
"serde",
"serde-untagged",
"serde_json",
"toml 0.9.5",
"winnow",
"yaml-rust2",
]
[[package]] [[package]]
name = "console" name = "console"
version = "0.15.11" version = "0.15.11"
...@@ -906,12 +952,41 @@ version = "0.9.6" ...@@ -906,12 +952,41 @@ version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
name = "const-random"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359"
dependencies = [
"const-random-macro",
]
[[package]]
name = "const-random-macro"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e"
dependencies = [
"getrandom 0.2.16",
"once_cell",
"tiny-keccak",
]
[[package]] [[package]]
name = "constant_time_eq" name = "constant_time_eq"
version = "0.3.1" version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
[[package]]
name = "convert_case"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca"
dependencies = [
"unicode-segmentation",
]
[[package]] [[package]]
name = "core-foundation" name = "core-foundation"
version = "0.9.4" version = "0.9.4"
...@@ -1320,6 +1395,15 @@ dependencies = [ ...@@ -1320,6 +1395,15 @@ dependencies = [
"pyo3", "pyo3",
] ]
[[package]]
name = "dlv-list"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f"
dependencies = [
"const-random",
]
[[package]] [[package]]
name = "dunce" name = "dunce"
version = "1.0.5" version = "1.0.5"
...@@ -1420,6 +1504,8 @@ dependencies = [ ...@@ -1420,6 +1504,8 @@ dependencies = [
"memmap2", "memmap2",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"modelexpress-client",
"modelexpress-common",
"ndarray", "ndarray",
"nix 0.26.4", "nix 0.26.4",
"nixl-sys", "nixl-sys",
...@@ -1493,6 +1579,7 @@ dependencies = [ ...@@ -1493,6 +1579,7 @@ dependencies = [
"futures", "futures",
"local-ip-address", "local-ip-address",
"once_cell", "once_cell",
"parking_lot",
"prometheus", "prometheus",
"pyo3", "pyo3",
"pyo3-async-runtimes", "pyo3-async-runtimes",
...@@ -1823,7 +1910,7 @@ dependencies = [ ...@@ -1823,7 +1910,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"tempfile", "tempfile",
"toml", "toml 0.8.23",
"uncased", "uncased",
"version_check", "version_check",
] ]
...@@ -1850,6 +1937,12 @@ version = "1.0.7" ...@@ -1850,6 +1937,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]] [[package]]
name = "form_urlencoded" name = "form_urlencoded"
version = "1.2.2" version = "1.2.2"
...@@ -2378,6 +2471,18 @@ name = "hashbrown" ...@@ -2378,6 +2471,18 @@ name = "hashbrown"
version = "0.15.5" version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashlink"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [
"hashbrown 0.15.5",
]
[[package]] [[package]]
name = "heck" name = "heck"
...@@ -2909,6 +3014,47 @@ version = "1.0.15" ...@@ -2909,6 +3014,47 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jiff"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49"
dependencies = [
"jiff-static",
"jiff-tzdb-platform",
"log",
"portable-atomic",
"portable-atomic-util",
"serde",
"windows-sys 0.52.0",
]
[[package]]
name = "jiff-static"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "jiff-tzdb"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1283705eb0a21404d2bfd6eef2a7593d240bc42a0bdb39db0ad6fa2ec026524"
[[package]]
name = "jiff-tzdb-platform"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "875a5a69ac2bab1a891711cf5eccbec1ce0341ea805560dcd90b7a2e925132e8"
dependencies = [
"jiff-tzdb",
]
[[package]] [[package]]
name = "jobserver" name = "jobserver"
version = "0.1.34" version = "0.1.34"
...@@ -2945,6 +3091,17 @@ dependencies = [ ...@@ -2945,6 +3091,17 @@ dependencies = [
"unicode-general-category", "unicode-general-category",
] ]
[[package]]
name = "json5"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1"
dependencies = [
"pest",
"pest_derive",
"serde",
]
[[package]] [[package]]
name = "jwalk" name = "jwalk"
version = "0.8.1" version = "0.8.1"
...@@ -3336,6 +3493,50 @@ dependencies = [ ...@@ -3336,6 +3493,50 @@ dependencies = [
"ws2_32-sys", "ws2_32-sys",
] ]
[[package]]
name = "modelexpress-client"
version = "0.1.0"
source = "git+https://github.com/ai-dynamo/modelexpress.git?rev=a232220bf268a475d293914d407f4ae186f443e3#a232220bf268a475d293914d407f4ae186f443e3"
dependencies = [
"anyhow",
"clap",
"colored",
"futures",
"modelexpress-common",
"prost",
"serde",
"serde_json",
"thiserror 2.0.16",
"tokio",
"tonic",
"tracing",
"tracing-subscriber",
"uuid",
]
[[package]]
name = "modelexpress-common"
version = "0.1.0"
source = "git+https://github.com/ai-dynamo/modelexpress.git?rev=a232220bf268a475d293914d407f4ae186f443e3#a232220bf268a475d293914d407f4ae186f443e3"
dependencies = [
"anyhow",
"async-trait",
"chrono",
"clap",
"config",
"hf-hub",
"jiff",
"prost",
"serde",
"serde_json",
"serde_yaml",
"thiserror 2.0.16",
"tokio",
"tonic",
"tonic-build",
"tracing",
]
[[package]] [[package]]
name = "monostate" name = "monostate"
version = "0.1.14" version = "0.1.14"
...@@ -3762,6 +3963,16 @@ version = "0.2.0" ...@@ -3762,6 +3963,16 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordered-multimap"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79"
dependencies = [
"dlv-list",
"hashbrown 0.14.5",
]
[[package]] [[package]]
name = "os_info" name = "os_info"
version = "3.12.0" version = "3.12.0"
...@@ -3815,6 +4026,12 @@ version = "1.0.15" ...@@ -3815,6 +4026,12 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]] [[package]]
name = "pear" name = "pear"
version = "0.2.9" version = "0.2.9"
...@@ -3853,6 +4070,50 @@ version = "2.3.2" ...@@ -3853,6 +4070,50 @@ version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pest"
version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21e0a3a33733faeaf8651dfee72dd0f388f0c8e5ad496a3478fa5a922f49cfa8"
dependencies = [
"memchr",
"thiserror 2.0.16",
"ucd-trie",
]
[[package]]
name = "pest_derive"
version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc58706f770acb1dbd0973e6530a3cff4746fb721207feb3a8a6064cd0b6c663"
dependencies = [
"pest",
"pest_generator",
]
[[package]]
name = "pest_generator"
version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d4f36811dfe07f7b8573462465d5cb8965fffc2e71ae377a33aecf14c2c9a2f"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "pest_meta"
version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42919b05089acbd0a5dcd5405fb304d17d1053847b81163d09c4ad18ce8e8420"
dependencies = [
"pest",
"sha2",
]
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.7.1" version = "0.7.1"
...@@ -4779,6 +5040,18 @@ dependencies = [ ...@@ -4779,6 +5040,18 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "ron"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94"
dependencies = [
"base64 0.21.7",
"bitflags 2.9.3",
"serde",
"serde_derive",
]
[[package]] [[package]]
name = "rstest" name = "rstest"
version = "0.25.0" version = "0.25.0"
...@@ -4809,6 +5082,16 @@ dependencies = [ ...@@ -4809,6 +5082,16 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "rust-ini"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "796e8d2b6696392a43bea58116b667fb4c29727dc5abd27d6acf338bb4f688c7"
dependencies = [
"cfg-if 1.0.3",
"ordered-multimap",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.26" version = "0.1.26"
...@@ -5144,6 +5427,17 @@ dependencies = [ ...@@ -5144,6 +5427,17 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-untagged"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34836a629bcbc6f1afdf0907a744870039b1e14c0561cb26094fa683b158eff3"
dependencies = [
"erased-serde",
"serde",
"typeid",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.219" version = "1.0.219"
...@@ -5207,6 +5501,15 @@ dependencies = [ ...@@ -5207,6 +5501,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_spanned"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40734c41988f7306bb04f0ecf60ec0f3f1caa34290e4e8ea471dcd3346483b83"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
...@@ -5251,6 +5554,19 @@ dependencies = [ ...@@ -5251,6 +5554,19 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap 2.11.0",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]] [[package]]
name = "sha1" name = "sha1"
version = "0.10.6" version = "0.10.6"
...@@ -5556,7 +5872,7 @@ dependencies = [ ...@@ -5556,7 +5872,7 @@ dependencies = [
"cfg-expr", "cfg-expr",
"heck", "heck",
"pkg-config", "pkg-config",
"toml", "toml 0.8.23",
"version-compare", "version-compare",
] ]
...@@ -5876,11 +6192,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -5876,11 +6192,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362"
dependencies = [ dependencies = [
"serde", "serde",
"serde_spanned", "serde_spanned 0.6.9",
"toml_datetime", "toml_datetime 0.6.11",
"toml_edit", "toml_edit",
] ]
[[package]]
name = "toml"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75129e1dc5000bfbaa9fee9d1b21f974f9fbad9daec557a521ee6e080825f6e8"
dependencies = [
"serde",
"serde_spanned 1.0.0",
"toml_datetime 0.7.0",
"toml_parser",
"winnow",
]
[[package]] [[package]]
name = "toml_datetime" name = "toml_datetime"
version = "0.6.11" version = "0.6.11"
...@@ -5890,6 +6219,15 @@ dependencies = [ ...@@ -5890,6 +6219,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "toml_datetime"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "toml_edit" name = "toml_edit"
version = "0.22.27" version = "0.22.27"
...@@ -5898,12 +6236,21 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" ...@@ -5898,12 +6236,21 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
dependencies = [ dependencies = [
"indexmap 2.11.0", "indexmap 2.11.0",
"serde", "serde",
"serde_spanned", "serde_spanned 0.6.9",
"toml_datetime", "toml_datetime 0.6.11",
"toml_write", "toml_write",
"winnow", "winnow",
] ]
[[package]]
name = "toml_parser"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cf893c33be71572e0e9aa6dd15e6677937abd686b066eac3f8cd3531688a627"
dependencies = [
"winnow",
]
[[package]] [[package]]
name = "toml_write" name = "toml_write"
version = "0.1.2" version = "0.1.2"
...@@ -6108,6 +6455,12 @@ version = "1.18.0" ...@@ -6108,6 +6455,12 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "ucd-trie"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]] [[package]]
name = "ug" name = "ug"
version = "0.4.0" version = "0.4.0"
...@@ -6275,6 +6628,12 @@ version = "0.2.4" ...@@ -6275,6 +6628,12 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
...@@ -6875,6 +7234,17 @@ version = "0.8.15" ...@@ -6875,6 +7234,17 @@ version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3"
[[package]]
name = "yaml-rust2"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2462ea039c445496d8793d052e13787f2b90e750b833afee748e601c17621ed9"
dependencies = [
"arraydeque",
"encoding_rs",
"hashlink",
]
[[package]] [[package]]
name = "yansi" name = "yansi"
version = "1.0.1" version = "1.0.1"
......
...@@ -39,6 +39,7 @@ either = { version = "1.13", features = ["serde"] } ...@@ -39,6 +39,7 @@ either = { version = "1.13", features = ["serde"] }
futures = { version = "0.3" } futures = { version = "0.3" }
local-ip-address = { version = "0.6" } local-ip-address = { version = "0.6" }
once_cell = { version = "1.20.3" } once_cell = { version = "1.20.3" }
parking_lot = { version = "0.12.4" }
rand = { version = "0.9" } rand = { version = "0.9" }
socket2 = { version = "0.6" } socket2 = { version = "0.6" }
serde = { version = "1" } serde = { version = "1" }
......
...@@ -54,6 +54,7 @@ mod engine; ...@@ -54,6 +54,7 @@ mod engine;
mod http; mod http;
mod llm; mod llm;
mod parsers; mod parsers;
mod planner;
type JsonServerStreamingIngress = type JsonServerStreamingIngress =
Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>; Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>;
...@@ -120,6 +121,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -120,6 +121,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::KvPushRouter>()?; m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?; m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
m.add_class::<planner::VirtualConnectorCoordinator>()?;
m.add_class::<planner::VirtualConnectorClient>()?;
m.add_class::<planner::PlannerDecision>()?;
engine::add_to_module(m)?; engine::add_to_module(m)?;
parsers::add_to_module(m)?; parsers::add_to_module(m)?;
...@@ -234,7 +238,7 @@ pub struct DistributedRuntime { ...@@ -234,7 +238,7 @@ pub struct DistributedRuntime {
impl DistributedRuntime { impl DistributedRuntime {
#[allow(dead_code)] #[allow(dead_code)]
fn inner(&self) -> &rs::DistributedRuntime { pub(crate) fn inner(&self) -> &rs::DistributedRuntime {
&self.inner &self.inner
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TODO: This was ported directly from Python, so some changes may be beneficial.
//! - Do we really want to convert to/from a string before writing to etcd? It takes a `Vec<u8>`.
//! - We can probably wrap the whole InnerConnector in a Mutex instead; it should be uncontended.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, SystemTime};
use parking_lot::Mutex;
use pyo3::{exceptions::PyException, prelude::*};
use super::to_pyerr;
use dynamo_runtime::CancellationToken;
use dynamo_runtime::transports::etcd::{Client, KvCache};
// An `AtomicUsize` cannot hold an `Option`, so `usize::MAX` is reserved to mean "not set".
const NONE_SENTINEL: usize = usize::MAX;
/// State shared by every clone of `VirtualConnectorCoordinator`.
///
/// The atomic fields use `NONE_SENTINEL` to represent "not set".
struct InnerConnector {
    /// Sleep between polls in `wait_for_scaling_completion`.
    check_interval: Duration,
    /// How long `update_scaling_decision` keeps skipping new decisions while the
    /// previous one has not been acknowledged.
    max_wait_time: Duration,
    /// Maximum number of polls performed by `wait_for_scaling_completion`.
    max_retries: usize,
    /// Dynamo namespace; `root_key` turns it into the etcd key prefix.
    namespace: String,
    /// etcd client taken from the distributed runtime at construction time.
    etcd_client: Client,
    // We need a mutex because we are `async`, but it should never be contended, planner should
    // be calling it from max one thread at once.
    // `None` until `async_init` has run.
    kv_cache: Mutex<Option<Arc<KvCache>>>,
    // On x86 AtomicUsize at Relaxed compiles to usize, it's free
    num_prefill_workers: AtomicUsize,
    num_decode_workers: AtomicUsize,
    decision_id: AtomicUsize, // NONE_SENTINEL means not set
    first_skip_timestamp: AtomicUsize, // In seconds since epoch, with NONE_SENTINEL
}
/// Planner-side handle that publishes scaling decisions through etcd.
///
/// Cloning is cheap: all clones share one `InnerConnector` behind an `Arc`.
#[pyclass]
#[derive(Clone)]
pub struct VirtualConnectorCoordinator(Arc<InnerConnector>);
#[pymethods]
impl VirtualConnectorCoordinator {
/// Build a coordinator for `dynamo_namespace`.
///
/// All counters start out as "not set"; call `async_init` afterwards to connect the
/// key-value cache and load any state already stored in etcd.
///
/// # Panics
/// Panics if the runtime has no etcd client.
#[new]
pub fn new(
    runtime: super::DistributedRuntime,
    dynamo_namespace: &str,
    check_interval_secs: usize,
    max_wait_time_secs: usize,
    max_retries: usize,
) -> Self {
    let unset = || AtomicUsize::new(NONE_SENTINEL);
    Self(Arc::new(InnerConnector {
        check_interval: Duration::from_secs(check_interval_secs as u64),
        max_wait_time: Duration::from_secs(max_wait_time_secs as u64),
        max_retries,
        namespace: dynamo_namespace.to_string(),
        etcd_client: runtime
            .inner()
            .etcd_client()
            .expect("Planner cannot run without etcd / in static mode"),
        kv_cache: Mutex::new(None),
        num_prefill_workers: unset(),
        num_decode_workers: unset(),
        decision_id: unset(),
        first_skip_timestamp: unset(),
    }))
}
#[pyo3(signature = ())]
pub fn read_state(&self) -> PlannerDecision {
let current_prefill = load(&self.0.num_prefill_workers);
let current_decode = load(&self.0.num_decode_workers);
let current_decision_id = load(&self.0.decision_id);
PlannerDecision {
num_prefill_workers: if current_prefill != NONE_SENTINEL {
current_prefill as isize
} else {
-1
},
num_decode_workers: if current_decode != NONE_SENTINEL {
current_decode as isize
} else {
-1
},
decision_id: if current_decision_id != NONE_SENTINEL {
current_decision_id as isize
} else {
-1
},
}
}
#[pyo3(signature = ())]
pub fn async_init<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let prefix = root_key(&self.0.namespace);
let inner = self.0.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let kv_cache = KvCache::new(inner.etcd_client.clone(), prefix, HashMap::new())
.await
.map_err(to_pyerr)?;
*inner.kv_cache.lock() = Some(Arc::new(kv_cache));
inner.load_current_state().await.map_err(to_pyerr)
})
}
/// Publish a new scaling decision to etcd.
///
/// `num_prefill` / `num_decode` are the desired worker counts; `None` leaves that count
/// untouched. Both default to `None` so the method can be called with keyword arguments
/// only.
///
/// Behavior:
/// - If neither supplied count differs from the current in-memory value, nothing is written.
/// - If the previous decision has not been acknowledged yet, the new decision is skipped
///   until `max_wait_time` has elapsed since the first skip; after that it is published anyway.
/// - Each value is written to etcd first and the in-memory copy is updated only after the
///   write succeeded, so a failed write can be retried with the same values.
///
/// Raises a Python exception if `async_init` has not been called or an etcd write fails.
#[pyo3(signature = (num_prefill=None, num_decode=None))]
pub fn update_scaling_decision<'p>(
    &self,
    py: Python<'p>,
    num_prefill: Option<usize>,
    num_decode: Option<usize>,
) -> PyResult<Bound<'p, PyAny>> {
    let inner = self.0.clone();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let current_prefill = load(&inner.num_prefill_workers);
        let has_prefill_changed = num_prefill.is_some_and(|n| n != current_prefill);
        let current_decode = load(&inner.num_decode_workers);
        let has_decode_changed = num_decode.is_some_and(|n| n != current_decode);
        if !(has_prefill_changed || has_decode_changed) {
            tracing::info!(
                current_prefill,
                current_decode,
                "No scaling needed, skipping update"
            );
            return Ok(());
        }

        // Check if previous scaling is ready
        if !inner.is_scaling_ready().await {
            let current_time = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .map_err(to_pyerr)?
                .as_secs() as usize;

            // If this is the first time we're skipping, record the timestamp
            let first_skip = match load(&inner.first_skip_timestamp) {
                NONE_SENTINEL => {
                    inner
                        .first_skip_timestamp
                        .store(current_time, Ordering::Relaxed);
                    tracing::info!(
                        decision_id = load(&inner.decision_id),
                        "Previous scaling decision not ready, starting to track skip time"
                    );
                    current_time
                }
                recorded => recorded,
            };

            // Check if we've been waiting too long. Saturating: the wall clock may step
            // backwards, and an unsigned underflow must not panic or wrap.
            let time_waited = current_time.saturating_sub(first_skip);
            if time_waited < inner.max_wait_time.as_secs() as usize {
                tracing::warn!(
                    decision_id = load(&inner.decision_id),
                    time_waited,
                    "Previous scaling decision not ready, skipping new decision",
                );
                return Ok(());
            }
            tracing::warn!(
                decision_id = load(&inner.decision_id),
                scaling_max_wait_time = inner.max_wait_time.as_secs(),
                "Previous scaling decision not ready, proceeding with new decision anyway"
            );
        }

        // Reset the skip timestamp since we're making a decision
        inner
            .first_skip_timestamp
            .store(NONE_SENTINEL, Ordering::Relaxed);

        let Some(kv_cache) = inner.kv_cache.lock().as_ref().cloned() else {
            return Err(PyErr::new::<PyException, _>(
                "Call async_init before using this object",
            ));
        };

        // Write to etcd first; only remember the value once the write went through.
        if let Some(new_prefill) = num_prefill {
            kv_cache
                .put(
                    "num_prefill_workers",
                    new_prefill.to_string().into_bytes(),
                    None,
                )
                .await
                .map_err(to_pyerr)?;
            inner
                .num_prefill_workers
                .store(new_prefill, Ordering::Relaxed);
        }

        if let Some(new_decode) = num_decode {
            kv_cache
                .put(
                    "num_decode_workers",
                    new_decode.to_string().into_bytes(),
                    None,
                )
                .await
                .map_err(to_pyerr)?;
            inner
                .num_decode_workers
                .store(new_decode, Ordering::Relaxed);
        }

        // The first decision gets id 0, every later one increments the previous id.
        let new_decision_id = match load(&inner.decision_id) {
            NONE_SENTINEL => 0,
            previous => previous.wrapping_add(1),
        };
        kv_cache
            .put(
                "decision_id",
                new_decision_id.to_string().into_bytes(),
                None,
            )
            .await
            .map_err(to_pyerr)?;
        inner.decision_id.store(new_decision_id, Ordering::Relaxed);

        tracing::info!(
            decision_id = new_decision_id,
            ?num_prefill,
            ?num_decode,
            "Updated scaling decision"
        );
        Ok(())
    })
}
/// Poll until the client has acknowledged the current decision.
///
/// Checks `scaled_decision_id` up to `max_retries` times, sleeping `check_interval`
/// after every unsuccessful check (key missing, value unparseable, or value older than
/// the current decision). Returns normally both on success and on timeout; a timeout is
/// only logged.
///
/// Raises a Python exception if `async_init` has not been called.
#[pyo3(signature = ())]
pub fn wait_for_scaling_completion<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
    let inner = self.0.clone();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let Some(kv_cache) = inner.kv_cache.lock().as_ref().cloned() else {
            return Err(PyErr::new::<PyException, _>(
                "Call async_init before using this object",
            ));
        };

        for _ in 0..inner.max_retries {
            if let Some(scaled_decision_id_bytes) = kv_cache.get("scaled_decision_id").await {
                match String::from_utf8_lossy(&scaled_decision_id_bytes).parse::<usize>() {
                    Ok(scaled_decision_id) => {
                        let current = load(&inner.decision_id);
                        if scaled_decision_id >= current || current == NONE_SENTINEL {
                            tracing::info!(
                                decision_id = current,
                                "Scaling decision completed"
                            );
                            return Ok(());
                        }
                    }
                    Err(err) => {
                        tracing::warn!(%err, "Failed to parse scaled_decision_id");
                    }
                }
            }
            // Not acknowledged yet: wait before the next check in every case, otherwise a
            // stale or malformed value would exhaust all retries without any delay.
            tokio::time::sleep(inner.check_interval).await;
        }

        tracing::warn!(
            decision_id = load(&inner.decision_id),
            scaling_max_wait_time = inner.max_wait_time.as_secs(),
            "Timeout waiting for scaling decision to complete"
        );
        Ok(())
    })
}
}
impl InnerConnector {
/// Copy the values currently held by the key-value cache into the in-memory counters.
///
/// Keys that are absent leave the corresponding counter untouched. An unparseable worker
/// count falls back to 0 and an unparseable decision id to "not set"; both are logged.
/// Fails if the cache has not been created yet.
async fn load_current_state(&self) -> PyResult<()> {
    let kv_cache = self
        .kv_cache
        .lock()
        .as_ref()
        .cloned()
        .ok_or_else(|| PyErr::new::<PyException, _>("Call async_init before using this object"))?;

    let stored = kv_cache.get_all().await;

    if let Some(raw) = stored.get("num_prefill_workers") {
        let count = String::from_utf8_lossy(raw)
            .parse::<usize>()
            .unwrap_or_else(|err| {
                tracing::error!(
                    "Failed to parse num_prefill_workers from ETCD, using default 0: {err}"
                );
                0
            });
        self.num_prefill_workers.store(count, Ordering::Relaxed);
    }

    if let Some(raw) = stored.get("num_decode_workers") {
        let count = String::from_utf8_lossy(raw)
            .parse::<usize>()
            .unwrap_or_else(|err| {
                tracing::error!(
                    "Failed to parse num_decode_workers from ETCD, using default 0: {err}"
                );
                0
            });
        self.num_decode_workers.store(count, Ordering::Relaxed);
    }

    if let Some(raw) = stored.get("decision_id") {
        let id = String::from_utf8_lossy(raw)
            .parse::<usize>()
            .unwrap_or_else(|err| {
                tracing::error!(
                    "Failed to parse decision_id from ETCD, using default None: {err}"
                );
                NONE_SENTINEL
            });
        self.decision_id.store(id, Ordering::Relaxed);
    }

    Ok(())
}
/// Check whether the previous scaling decision has been completed.
///
/// Returns `true` when no decision has been made yet, or when the stored
/// `scaled_decision_id` is at least the current decision id. Returns `false` when the
/// cache is not initialised, the key is missing, or its value cannot be parsed.
async fn is_scaling_ready(&self) -> bool {
    let current = load(&self.decision_id);
    if current == NONE_SENTINEL {
        // Nothing has been published yet, so there is nothing to wait for.
        return true;
    }

    let Some(kv_cache) = self.kv_cache.lock().as_ref().cloned() else {
        tracing::warn!("Call async_init before using this object");
        return false;
    };

    // No acknowledgement written yet: assume not ready.
    let Some(raw) = kv_cache.get("scaled_decision_id").await else {
        return false;
    };

    match String::from_utf8_lossy(&raw).parse::<usize>() {
        // `current` is known not to be the sentinel here.
        Ok(scaled_decision_id) => scaled_decision_id >= current,
        Err(err) => {
            tracing::warn!(%err, "Failed to parse scaled_decision_id");
            false
        }
    }
}
}
/// Client-side handle used to read planner decisions from etcd and acknowledge them.
///
/// Cloning is cheap: all clones share one `InnerClient` behind an `Arc`.
#[pyclass]
#[derive(Clone)]
pub struct VirtualConnectorClient(Arc<InnerClient>);
#[pymethods]
impl VirtualConnectorClient {
/// Build a client for `dynamo_namespace`.
///
/// # Panics
/// Panics if the runtime has no etcd client.
#[new]
pub fn new(runtime: super::DistributedRuntime, dynamo_namespace: &str) -> Self {
    let drt = runtime.inner();
    let etcd_client = drt
        .etcd_client()
        .expect("Planner cannot run without etcd / in static mode");
    let key = root_key(dynamo_namespace);
    // Child token: `wait` is aborted when the runtime shuts down.
    let cancellation_token = drt.child_token();
    Self(Arc::new(InnerClient {
        key,
        etcd_client,
        cancellation_token,
    }))
}
/// Get the current values as a PlannerDecision
#[pyo3(signature = ())]
pub fn get<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.0.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.get().await.map_err(to_pyerr)
})
}
/// Mark this scaling decision complete
#[pyo3(signature = (event))]
pub fn complete<'p>(
&self,
py: Python<'p>,
event: PlannerDecision,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.0.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.complete(event).await.map_err(to_pyerr)
})
}
/// Wait until a new PlannerDecision appears. Will block until there is one to fetch.
/// Use `get` to fetch the decision.
#[pyo3(signature = ())]
pub fn wait<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.0.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.wait().await.map_err(to_pyerr)
})
}
}
#[pyclass]
#[derive(Clone, Copy)]
/// The decision Planner made. The client should make necessary changes to the environment to make
/// this true, and then call `complete` on the VirtualConnectorClient.
///
/// A value of -1 in any field means "not set".
pub struct PlannerDecision {
    /// Desired number of prefill workers, or -1 if not set.
    #[pyo3(get)]
    pub num_prefill_workers: isize,
    /// Desired number of decode workers, or -1 if not set.
    #[pyo3(get)]
    pub num_decode_workers: isize,
    /// Identifier of this decision, or -1 if no decision has been made.
    #[pyo3(get)]
    pub decision_id: isize,
}
/// State shared by every clone of `VirtualConnectorClient`.
struct InnerClient {
    /// etcd key prefix for this namespace, as produced by `root_key` (ends with `/`).
    key: String,
    /// etcd client taken from the distributed runtime at construction time.
    etcd_client: Client,
    /// Child token of the runtime; `wait` returns an error once it is cancelled.
    cancellation_token: CancellationToken,
}
impl InnerClient {
/// Fetch the latest scaling decision.
///
/// Reads every key under the planner prefix. Fields whose key is absent stay at -1.
/// Fails if a key or value is not valid UTF-8 or a value is not an integer.
async fn get(&self) -> anyhow::Result<PlannerDecision> {
    let mut decision = PlannerDecision {
        num_prefill_workers: -1,
        num_decode_workers: -1,
        decision_id: -1,
    };
    for kv in self.etcd_client.kv_get_prefix(&self.key).await? {
        let key = kv.key_str()?;
        if key.ends_with("/num_prefill_workers") {
            decision.num_prefill_workers = kv.value_str()?.parse()?;
        } else if key.ends_with("/num_decode_workers") {
            decision.num_decode_workers = kv.value_str()?.parse()?;
        } else if key.ends_with("/decision_id") {
            decision.decision_id = kv.value_str()?.parse()?;
        } else if key.ends_with("/scaled_decision_id") {
            // This is the client's response, it doesn't go in PlannerDecision
        } else {
            tracing::warn!(
                unexpected_key = key,
                root = self.key,
                "Unexpected key in planner etcd"
            );
        }
    }
    Ok(decision)
}
/// Mark this decision as having been handled by writing its id to
/// `<prefix>scaled_decision_id`.
async fn complete(&self, event: PlannerDecision) -> anyhow::Result<()> {
    let ack_key = format!("{}scaled_decision_id", self.key);
    let ack_value = event.decision_id.to_string();
    self.etcd_client
        .kv_put(ack_key, ack_value.as_bytes(), None)
        .await
}
/// Wait for a new scaling decision. Use `get` when this returns to fetch the values.
///
/// Returns `Ok` as soon as the watch channel yields anything, and an error if the
/// runtime is shut down first.
// NOTE(review): the watch covers the whole prefix, and `complete` writes
// `scaled_decision_id` under that same prefix, so the client's own acknowledgement
// presumably also wakes a pending `wait` — confirm whether that is intended.
async fn wait(&self) -> anyhow::Result<()> {
    let prefix_watch = self.etcd_client.kv_watch_prefix(&self.key).await?;
    // `_watcher` is a named binding (not `_`) so it stays alive until this function returns.
    let (_prefix, _watcher, mut events) = prefix_watch.dissolve();
    tokio::select! {
        _ = events.recv() => Ok(()),
        _ = self.cancellation_token.cancelled() => {
            Err(anyhow::anyhow!("VirtualConnectorClient.wait: Runtime shutdown"))
        }
    }
}
}
/// Read an atomic counter with `Relaxed` ordering.
fn load(a: &AtomicUsize) -> usize {
    AtomicUsize::load(a, Ordering::Relaxed)
}
/// Build the etcd key prefix for a namespace: `/<namespace>/planner/`.
fn root_key(namespace: &str) -> String {
    let mut key = String::with_capacity(namespace.len() + "//planner/".len());
    key.push('/');
    key.push_str(namespace);
    key.push_str("/planner/");
    key
}
...@@ -1359,3 +1359,48 @@ class EntrypointArgs: ...@@ -1359,3 +1359,48 @@ class EntrypointArgs:
""" """
... ...
class PlannerDecision:
    """A request from planner to client to perform a scaling action.

    Fields: num_prefill_workers, num_decode_workers, decision_id.
    -1 in any of those fields means not set, usually because planner hasn't decided anything yet.
    Call VirtualConnectorClient.complete(event) when the action is completed.
    """

    # Read-only attributes exposed by the native class.
    num_prefill_workers: int
    num_decode_workers: int
    decision_id: int
class VirtualConnectorCoordinator:
    """Internal planner virtual connector component"""

    def __init__(self, runtime: DistributedRuntime, dynamo_namespace: str, check_interval_secs: int, max_wait_time_secs: int, max_retries: int) -> None:
        ...

    async def async_init(self) -> None:
        """Call this before using the object"""
        ...

    def read_state(self) -> PlannerDecision:
        """Get the current values. Mostly for test / debug."""
        ...

    async def update_scaling_decision(self, num_prefill: Optional[int] = None, num_decode: Optional[int] = None) -> None:
        """Publish new target worker counts.

        Does nothing when neither count differs from the current value. While the
        previous decision is still unacknowledged the update is skipped, until
        max_wait_time_secs has elapsed since the first skip.
        """
        ...

    async def wait_for_scaling_completion(self) -> None:
        """Poll until the client acknowledges the current decision.

        Gives up after max_retries checks; a timeout is logged, not raised.
        """
        ...
class VirtualConnectorClient:
    """How a client discovers planner requests and marks them complete"""

    def __init__(self, runtime: DistributedRuntime, dynamo_namespace: str) -> None:
        ...

    async def get(self) -> PlannerDecision:
        """Fetch the current decision values."""
        ...

    # The parameter is named `event` to match the keyword accepted by the native method.
    async def complete(self, event: PlannerDecision) -> None:
        """Mark the given decision as handled."""
        ...

    async def wait(self) -> None:
        """Blocks until there is a new decision to fetch using 'get'"""
        ...
...@@ -321,9 +321,25 @@ impl Client { ...@@ -321,9 +321,25 @@ impl Client {
Ok(()) Ok(())
} }
/// Like kv_get_and_watch_prefix but only for new changes, does not include existing values.
pub async fn kv_watch_prefix(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
) -> Result<PrefixWatcher> {
self.watch_internal(prefix, false).await
}
pub async fn kv_get_and_watch_prefix( pub async fn kv_get_and_watch_prefix(
&self, &self,
prefix: impl AsRef<str> + std::fmt::Display, prefix: impl AsRef<str> + std::fmt::Display,
) -> Result<PrefixWatcher> {
self.watch_internal(prefix, true).await
}
async fn watch_internal(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
include_existing: bool,
) -> Result<PrefixWatcher> { ) -> Result<PrefixWatcher> {
let mut kv_client = self.client.kv_client(); let mut kv_client = self.client.kv_client();
let mut watch_client = self.client.watch_client(); let mut watch_client = self.client.watch_client();
...@@ -352,16 +368,23 @@ impl Client { ...@@ -352,16 +368,23 @@ impl Client {
) )
.await?; .await?;
let kvs = get_response.take_kvs(); let kvs = if include_existing {
tracing::trace!("initial kv count: {:?}", kvs.len()); let kvs = get_response.take_kvs();
tracing::trace!("initial kv count: {:?}", kvs.len());
kvs
} else {
vec![]
};
let (tx, rx) = mpsc::channel(32); let (tx, rx) = mpsc::channel(32);
self.rt.spawn(async move { self.rt.spawn(async move {
for kv in kvs { if include_existing {
if tx.send(WatchEvent::Put(kv)).await.is_err() { for kv in kvs {
// receiver is already closed if tx.send(WatchEvent::Put(kv)).await.is_err() {
return; // receiver is already closed
return;
}
} }
} }
...@@ -519,7 +542,7 @@ impl KvCache { ...@@ -519,7 +542,7 @@ impl KvCache {
} }
// Start watching for changes // Start watching for changes
// we won't miss events bewteen the initial push and the watcher starting because // we won't miss events between the initial push and the watcher starting because
// client.kv_get_and_watch_prefix() will get all kv pairs and put them back again // client.kv_get_and_watch_prefix() will get all kv pairs and put them back again
let watcher = client.kv_get_and_watch_prefix(&prefix).await?; let watcher = client.kv_get_and_watch_prefix(&prefix).await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment