"mmgen/datasets/vscode:/vscode.git/clone" did not exist on "b7536f78b8574f78c30bc1603be632dabfff5541"
Unverified commit 7d5d6f8c, authored by Hongkuan Zhou, committed by GitHub
parent 0715d469
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
  model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  block-size: 64
  max-model-len: 16384
  kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'

Frontend:
  served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
  endpoint: dynamo.Processor.chat/completions
  port: 8000

Processor:
  router: kv-load
  common-configs: [model, block-size]

VllmWorker:
  remote-prefill: true
  conditional-disagg: false
  ServiceArgs:
    workers: 2
    resources:
      gpu: 1
  common-configs: [model, block-size, max-model-len, kv-transfer-config]

PrefillWorker:
  max-num-batched-tokens: 16384
  ServiceArgs:
    workers: 2
    resources:
      gpu: 1
  common-configs: [model, block-size, max-model-len, kv-transfer-config]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import math
import numpy as np
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main(args):
output_data = []
def get_isl_osl(t):
isl_osl_ratio = (args.isl_osl_ratio_min + args.isl_osl_ratio_max) / 2 + (
args.isl_osl_ratio_max - args.isl_osl_ratio_min
) / 2 * np.sin(2 * np.pi / args.isl_osl_ratio_period * t)
logger.info(f"isl_osl_ratio at {t:.2f}: {isl_osl_ratio:.2f}")
if np.random.uniform(0, 1) < isl_osl_ratio:
return (args.isl1, args.osl1)
else:
return (args.isl2, args.osl2)
total_hash_ids = np.arange(args.total_blocks)
for t in tqdm(range(0, args.time_duration, args.process_interval)):
t_e = min(t + args.process_interval, args.time_duration)
request_rate = (args.request_rate_min + args.request_rate_max) / 2 + (
args.request_rate_max - args.request_rate_min
) / 2 * np.sin(2 * np.pi / args.request_rate_period * t)
logger.info(f"request_rate at {t:.2f}: {request_rate:.2f}")
num_requests = np.random.poisson(request_rate * (t_e - t))
for req_idx in range(num_requests):
t_req = t + (t_e - t) * req_idx / num_requests
isl, osl = get_isl_osl(t_req)
output_data.append(
{
"timestamp": t_req * 1000, # in ms
"input_length": isl,
"output_length": osl,
"hash_ids": np.random.choice(
total_hash_ids, size=math.ceil(isl / args.block_size)
).tolist(),
}
)
with open(args.output_file, "w") as f:
for item in output_data:
f.write(json.dumps(item) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate synthetic dataset with sinusoidal request rate and isl/osl"
)
parser.add_argument(
"--block-size", type=int, default=512, help="Block size for hashing"
)
parser.add_argument(
"--total-blocks",
type=int,
default=10000,
help="ISL prompt blocks are randomly sampled from this range",
)
parser.add_argument(
"--output-file",
type=str,
default=None,
help="Output file name (in jsonl format)",
)
parser.add_argument(
"--time-duration",
type=int,
default=100,
help="Time duration of the dataset in seconds",
)
parser.add_argument(
"--process-interval",
type=int,
default=1,
help="Sampling interval used to generate the dataset",
)
# request rate parameters
# for the process interval at [t, t + process_interval), the number of requests to generate is sampled
# from a Poisson distribution with the following parameters:
# request_rate(t) = (min + max) / 2 + (max - min) / 2 * sin(2 * pi / period * t)
# num_requests[t, t + process_interval) ~ Poisson(request_rate(t) * process_interval)
# requests are uniformly distributed in the interval [t, t + process_interval)
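# worked example with the defaults below (min=5, max=10, period=10 s):
# at t=2.5 s, sin(2*pi/10*2.5)=1, so request_rate(2.5)=10 req/s and the number of
# requests in that interval is drawn from Poisson(10 * process_interval)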
parser.add_argument(
"--request-rate-min",
type=float,
default=5,
help="Minimum request rate in requests per second",
)
parser.add_argument(
"--request-rate-max",
type=float,
default=10,
help="Maximum request rate in requests per second",
)
parser.add_argument(
"--request-rate-period",
type=float,
default=10,
help="Period of the sinusoidal request rate in seconds",
)
# isl/osl parameters
# isl/osl is randomly sampled from two candidates following the isl-osl-ratio.
# at time t, the isl-osl-ratio is calculated as:
# isl-osl-ratio(t) = (min + max) / 2 + (max - min) / 2 * sin(2 * pi / period * t)
# Then, we sample [isl1/osl1, isl2/osl2] from the distribution [isl-osl-ratio(t), 1 - isl-osl-ratio(t)]
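# worked example with the defaults below (ratio min=0.2, max=0.8, period=10 s):
# at t=2.5 s, sin(2*pi/10*2.5)=1, so isl-osl-ratio(2.5)=0.8, i.e. an 80% chance of
# sampling (isl1, osl1) and a 20% chance of sampling (isl2, osl2)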
parser.add_argument(
"--isl1", type=int, default=100, help="Input sequence length of the first candidate"
)
parser.add_argument(
"--osl1", type=int, default=2000, help="Output sequence length of the first candidate"
)
parser.add_argument(
"--isl2", type=int, default=5000, help="Input sequence length of the second candidate"
)
parser.add_argument(
"--osl2", type=int, default=100, help="Output sequence length of the second candidate"
)
parser.add_argument(
"--isl-osl-ratio-min",
type=float,
default=0.2,
help="Minimum ratio of input sequence length to output sequence length",
)
parser.add_argument(
"--isl-osl-ratio-max",
type=float,
default=0.8,
help="Maximum ratio of input sequence length to output sequence length",
)
parser.add_argument(
"--isl-osl-ratio-period",
type=float,
default=10,
help="Period of the sinusoidal input/output sequence length ratio",
)
args = parser.parse_args()
if args.output_file is None:
args.output_file = f"sin_b{args.block_size}_t{args.time_duration}_rr{args.request_rate_min}-{args.request_rate_max}-{args.request_rate_period}_io{args.isl1}{args.osl1}-{args.isl2}{args.osl2}-{args.isl_osl_ratio_min}-{args.isl_osl_ratio_max}-{args.isl_osl_ratio_period}.jsonl"
main(args)
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Planner
The planner is a component that monitors the state of the system and adjusts the workers to keep the system running efficiently. Currently, the planner can scale the number of vLLM workers up and down based on KV cache load and prefill queue size:
* Backend:
* local ✅
* kubernetes ✅
* LLM framework:
* vllm ✅
* tensorrt-llm ❌
* SGLang ❌
* llama.cpp ❌
* Serving type:
* Aggregated ✅
* Disaggregated ✅
* Planner actions:
* Load-based scaling up/down prefill/decode workers ✅
* SLA-based scaling up/down prefill/decode workers ❌
* Adjusting engine knobs ❌
## Load-based Scaling Up/Down Prefill/Decode Workers
To adjust the number of prefill/decode workers, the planner monitors the following metrics:
* Prefill worker: the planner monitors the number of requests pending in the prefill queue to estimate the prefill workload.
* Decode/aggregated worker: the planner monitors the average KV cache utilization rate to estimate the decode/aggregated workload.
Every `metric-pulling-interval`, the planner gathers the metrics above. Every `adjustment-interval`, the planner compares the metrics aggregated over that interval with pre-set thresholds and decides whether to scale prefill/decode workers up or down. To avoid over-compensation, the planner changes the number of workers by at most 1 per adjustment interval. In addition, while the number of workers is being adjusted, the planner blocks metric pulling and further adjustments.
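The timing of this loop (as implemented in `components/planner.py`, included further down in this change) can be sketched roughly as follows; the `collect_metrics`, `make_adjustments`, and `reset_interval` callables stand in for the planner's actual methods:

```python
import asyncio
import time


async def planner_loop(args, collect_metrics, make_adjustments, reset_interval):
    """Condensed sketch of the planner's timing logic, not the real implementation."""
    metric_times: list[float] = []
    last_adjustment = time.time()
    while True:
        now = time.time()
        # pull metrics every `metric-pulling-interval`
        if not metric_times or now - metric_times[-1] >= args.metric_pulling_interval:
            await collect_metrics()  # prefill queue size + KV cache utilization
            metric_times.append(time.time())
        # compare aggregated metrics against thresholds every `adjustment-interval`
        if now - last_adjustment >= args.adjustment_interval:
            if not args.no_operation:
                await make_adjustments()  # blocking; at most one worker added/removed
            await reset_interval()  # clear aggregated metrics for the next interval
            last_adjustment = time.time()
        await asyncio.sleep(args.metric_pulling_interval / 10)  # avoid busy waiting
```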
To scale up a prefill/decode worker, the planner just needs to launch the worker in the correct namespace; the auto-discovery mechanism picks up the new worker and adds it to the routers. To scale down a prefill worker, the planner sends a SIGTERM signal to the prefill worker. The prefill worker stores the signal and exits once it finishes the current request pulled from the prefill queue, which ensures that no remote prefill request is dropped. To scale down a decode worker, the planner currently revokes the decode worker's etcd lease. When the etcd lease is revoked, the corresponding decode worker is immediately removed from the router and receives no new requests. The decode worker then finishes the requests already in its streams and exits gracefully.
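For decode workers, the same lease-revocation path can also be driven manually; below is a minimal sketch mirroring the standalone revoke-lease utility added in this change (the function name and the lease id are illustrative):

```python
import asyncio

import uvloop

from dynamo.runtime import DistributedRuntime, dynamo_worker


@dynamo_worker(static=False)
async def revoke_decode_worker(runtime: DistributedRuntime, lease_id: int):
    # Revoking the lease removes the decode worker from the router immediately;
    # the worker then drains its in-flight requests and exits gracefully.
    await runtime.etcd_client().revoke_lease(lease_id)


if __name__ == "__main__":
    uvloop.install()
    asyncio.run(revoke_decode_worker(1234))  # illustrative lease id
```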
There are two additional rules the planner applies to prevent over-compensation:
1. After a new decode worker is added, since it needs time to populate its KV cache, the planner will not scale down the number of decode workers for the next `NEW_DECODE_WORKER_GRACE_PERIOD=3` adjustment intervals.
1. The planner does not scale up prefill workers if the prefill queue size is predicted to drop below `--prefill-queue-scale-up-threshold` within the next `NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD=3` adjustment intervals, following the trend observed in the current adjustment interval (see the sketch below).
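A condensed sketch of that trend check, matching the extrapolation done in the planner code below:

```python
NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD = 3


def should_scale_up_prefill(prefill_queue_load: list[float], scale_up_threshold: float) -> bool:
    # Extrapolate the queue-size change observed in this adjustment interval
    # forward by the buffer period; only scale up if the queue is still
    # predicted to sit above the threshold (otherwise the backlog is already
    # draining fast enough on its own).
    change = prefill_queue_load[-1] - prefill_queue_load[0]
    predicted = prefill_queue_load[-1] + change * NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD
    return predicted > scale_up_threshold
```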
## Usage
After you've deployed a Dynamo graph, you can start the planner with the following command:
```bash
python components/planner.py <arguments>
```
Planner takes the following arguments:
* `--namespace` (str, default: "dynamo"): Namespace planner will look at
* `--served-model-name` (str, default: "vllm"): Model name that is being served (used for the prefill queue name)
* `--no-operation` (flag): Do not make any adjustments, just observe the metrics and log to tensorboard
* `--log-dir` (str, default: None): Tensorboard logging directory
* `--adjustment-interval` (int, default: 10): Interval in seconds between scaling adjustments
* `--metric-pulling-interval` (int, default: 1): Interval in seconds between metric pulls
* `--max-gpu-budget` (int, default: 8): Maximum number of GPUs to use; the planner will not scale prefill plus decode workers above this budget
* `--min-endpoint` (int, default: 1): Minimum number of prefill/decode workers to keep; the planner will not scale below this number of workers for either role
* `--decode-kv-scale-up-threshold` (float, default: 0.9): KV cache utilization threshold to scale up decode workers
* `--decode-kv-scale-down-threshold` (float, default: 0.5): KV cache utilization threshold to scale down decode workers
* `--prefill-queue-scale-up-threshold` (float, default: 5): Prefill queue size threshold above which prefill workers are scaled up
* `--prefill-queue-scale-down-threshold` (float, default: 0.2): Prefill queue size threshold below which prefill workers are scaled down
* `--decode-engine-num-gpu` (int, default: 1): Number of GPUs per decode engine
* `--prefill-engine-num-gpu` (int, default: 1): Number of GPUs per prefill engine
### Tensorboard
Planner logs to tensorboard to visualize the metrics and the scaling actions. You can start tensorboard with the following command:
```bash
tensorboard --logdir=<path-to-tensorboard-log-dir>
```
## Backends
We currently only support one backend:
1. `local` - uses circus to start/stop worker subprocesses
### Local Backend
Circus is a Python program which can be used to monitor and control processes and sockets. Dynamo serve uses circus to start each node in a graph and monitors each subprocess. We leverage a core circus feature called a `Watcher` to do this. A `Watcher` is the target program that you would like to run (which in our case is `serve_dynamo.py`). When the planner decides to scale up or down, it adds or removes a watcher from the existing circus.
> [!NOTE]
> Although circus allows you to `increment` an existing watcher, it was not designed to let per-process variables be passed in, which prevents us from scheduling each process on a specific GPU. So instead we start a new watcher per process. When the planner decides to add or remove a worker, we have logic that handles adding/removing watchers and incrementing/decrementing the worker count.
#### Statefile
The statefile is a JSON file created when initially running `dynamo serve` and filled in with custom leases in `serve_dynamo`. Each worker is named `{namespace}_{component_name}` when it is initially created. The `resources` come from the allocator and allow us to keep track of which GPUs are available. This statefile is read by the LocalConnector, and after each planner update we make the relevant change to the statefile. Currently, the statefile is saved locally in `~/.dynamo/state/{namespace}.json` (or in `DYN_LOCAL_STATE_DIR`) and is automatically cleaned up when the arbiter dies.
When one Decode worker is spun up, the statefile looks like:
```json
{
"dynamo_VllmWorker": {..., resources={...}},
}
```
Now another decode worker is added:
```json
{
"dynamo_VllmWorker": {..., resources={...}},
"dynamo_VllmWorker_1": {..., resources={...}},
}
```
Then one decode worker is removed:
```json
{
"dynamo_VllmWorker": {..., resources={...}},
}
```
If the last decode worker is removed, the statefile looks like:
```json
{
"dynamo_VllmWorker": {...},
}
```
Note that we keep the initial non-suffixed entry so that we know what command is needed to spin up another worker. The same applies to prefill workers.
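For reference, here is a minimal sketch of loading the statefile from the default location described above (the helper name is hypothetical):

```python
import json
import os


def load_statefile(namespace: str = "dynamo") -> dict:
    # The statefile lives in ~/.dynamo/state/{namespace}.json unless
    # DYN_LOCAL_STATE_DIR points at a different directory.
    state_dir = os.environ.get(
        "DYN_LOCAL_STATE_DIR", os.path.expanduser("~/.dynamo/state")
    )
    with open(os.path.join(state_dir, f"{namespace}.json")) as f:
        return json.load(f)


# e.g. list the tracked workers: print(list(load_statefile().keys()))
```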
> [!NOTE]
> At the moment, the planner works best if your initial replicas per worker are 1. This is because if you specify replicas > 1 when you initially start `dynamo serve`, the current implementation in `serving.py` starts each process in the same watcher.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
import os
import time
from datetime import datetime
from typing import Any, List
import numpy as np
from rich.console import Console
from rich.table import Table
from tensorboardX import SummaryWriter
from utils.prefill_queue import PrefillQueue
from dynamo.llm import KvMetricsAggregator
from dynamo.planner import LocalConnector
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.sdk.lib.logging import configure_server_logging
configure_server_logging()
logger = logging.getLogger(__name__)
# will not decrease decode worker number within 3 adjustment intervals after a new decode worker
# is added. this is to leave time for the new decode worker to populate its kv cache.
NEW_DECODE_WORKER_GRACE_PERIOD = 3
# we do not scale up prefill workers if the prefill queue size is estimated to drop below
# --prefill-queue-scale-up-threshold within the next NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD
# adjustment intervals, following the trend observed in the current adjustment interval.
# this is to account for the time for prefill workers to start.
NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD = 3
class Planner:
def __init__(self, runtime: DistributedRuntime, args: argparse.Namespace):
self.runtime = runtime
self.args = args
self.namespace = args.namespace
self.connector = LocalConnector(args.namespace, runtime)
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.args.served_model_name
self.prefill_client: Any | None = None
self.workers_client: Any | None = None
self.p_endpoints: List[int] = []
self.d_endpoints: List[int] = []
self.decode_worker_remaining_grace_period = 0
if args.log_dir is None:
args.log_dir = f"logs/{datetime.now().strftime('%m%d_%H%M%S')}"
self.writer = SummaryWriter(args.log_dir)
logger.info(f"Components present in namespace: {args.namespace}")
self.init_time = time.time()
async def set_metric_aggregator(self):
# TODO: separate KV metrics and prefill metrics
kv_listener = self.runtime.namespace(self.namespace).component("VllmWorker")
await kv_listener.create_service()
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
async def get_workers_info(self):
try:
if self.prefill_client is None:
self.prefill_client = (
await self.runtime.namespace(self.namespace)
.component("PrefillWorker")
.endpoint("mock")
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling endpoints_ids
p_endpoints = self.prefill_client.endpoint_ids()
except Exception:
p_endpoints = []
logger.info("No prefill workers found, operating in aggregated mode")
try:
if self.workers_client is None:
self.workers_client = (
await self.runtime.namespace(self.namespace)
.component("VllmWorker")
.endpoint("generate")
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling endpoints_ids
d_endpoints = self.workers_client.endpoint_ids()
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints
async def reset_adjustment_interval(self):
logger.info(
f"Reset metrics for new adjustment interval at t={time.time() - self.init_time:.1f}s"
)
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
self.metrics_collection_time = []
self.prefill_queue_load = []
self.kv_load = []
self.last_adjustment_time = time.time()
async def collect_metrics(self):
logger.info(f"Collecting metrics at t={time.time() - self.init_time:.1f}s")
# collect prefill queue load
try:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
measure_time = time.time() - self.init_time
self.prefill_queue_load.append(prefill_queue_size)
logger.info(
f"Collected prefill queue size at t={measure_time:.1f}s: {int(prefill_queue_size)}"
)
self.writer.add_scalar(
"prefill_queue_size", prefill_queue_size, measure_time
)
except Exception as e:
logger.info(f"Failed to collect prefill queue size metrics: {e}")
# collect kv load
total_active_requests: int = 0
total_queued_requests: int = 0
metrics = await self.metrics_aggregator.get_metrics()
try:
prev_kv_load_len = len(self.kv_load)
for endpoint in metrics.endpoints:
kv_load = getattr(endpoint, "gpu_cache_usage_perc", 0.0)
num_requests_waiting = getattr(endpoint, "num_requests_waiting", 0)
total_queued_requests += num_requests_waiting
request_active_slots = getattr(endpoint, "request_active_slots", None)
if request_active_slots:
total_active_requests += request_active_slots
if num_requests_waiting > 0:
# estimate kv load after waiting requests are scheduled based on current isl/osl
# TODO: use actual isl/osl estimation after the request_active_slot bug in disaggg is fixed
# Currently, we assume each request uses 0.02 kv cache
# kv_load = kv_load * (request_active_slots + num_requests_waiting) / request_active_slots
kv_load = kv_load + 0.02 * num_requests_waiting
self.kv_load.append(kv_load)
measure_time = time.time() - self.init_time
logger.info(
f"Collected kv load at t={measure_time:.1f}s: {self.kv_load[prev_kv_load_len:]} (act/pnd req: {total_active_requests}/{total_queued_requests})"
)
average_kv_load = np.mean(self.kv_load[prev_kv_load_len:])
self.writer.add_scalar("average_kv_load", average_kv_load, measure_time)
self.writer.add_scalar(
"total_queued_requests", total_queued_requests, measure_time
)
except Exception as e:
logger.info(f"Failed to collect kv load metrics: {e}")
p_endpoints, d_endpoints = await self.get_workers_info()
self.writer.add_scalar(
"num_prefill_workers", len(p_endpoints), time.time() - self.init_time
)
self.writer.add_scalar(
"num_decode_workers", len(d_endpoints), time.time() - self.init_time
)
curr_gpu_usage = (
len(p_endpoints) * self.args.prefill_engine_num_gpu
+ len(d_endpoints) * self.args.decode_engine_num_gpu
)
self.writer.add_scalar("num_gpu", curr_gpu_usage, time.time() - self.init_time)
self.metrics_collection_time.append(time.time())
async def make_adjustments(self):
# Note: all adjustments are blocking. Non-blocking adjustment and metric pulling
# make the optimization problem too complex and should not be needed in most cases.
logger.info(f"Making adjustments at t={time.time() - self.init_time:.1f}s")
# check if decode/prefill workers is still the same
# note that we only check length as endpoint ids might change
new_p_endpoints, new_d_endpoints = await self.get_workers_info()
if len(new_p_endpoints) != len(self.p_endpoints) or len(new_d_endpoints) != len(
self.d_endpoints
):
logger.info("Decode/prefill workers changed, no adjustments will be made")
return
# compute current gpu usage
curr_gpu_usage = (
len(self.p_endpoints) * self.args.prefill_engine_num_gpu
+ len(self.d_endpoints) * self.args.decode_engine_num_gpu
)
logger.info(f"Current engines use {curr_gpu_usage} GPUs")
avg_prefill_queue_load = np.mean(self.prefill_queue_load)
avg_kv_load = np.mean(self.kv_load)
# first check if we need to scale down any workers
if (
avg_prefill_queue_load < self.args.prefill_queue_scale_down_threshold
and len(self.p_endpoints) > self.args.min_endpoint
):
logger.info(
f"Average prefill queue load ({avg_prefill_queue_load:.2f}) is below threshold ({self.args.prefill_queue_scale_down_threshold:.2f}), scaling down prefill workers"
)
success = await self.connector.remove_component("PrefillWorker")
if success:
curr_gpu_usage -= self.args.prefill_engine_num_gpu
else:
logger.info("Failed to scale down prefill worker")
if (
avg_kv_load < self.args.decode_kv_scale_down_threshold
and len(self.d_endpoints) > self.args.min_endpoint
):
if self.decode_worker_remaining_grace_period > 0:
logger.info(
f"Decode worker remaining grace period is {self.decode_worker_remaining_grace_period}, skipping scale down"
)
else:
logger.info(
f"Average kv load ({avg_kv_load:.2f}) is below threshold ({self.args.decode_kv_scale_down_threshold:.2f}), scaling down decode workers"
)
success = await self.connector.remove_component("VllmWorker")
if success:
curr_gpu_usage -= self.args.decode_engine_num_gpu
else:
logger.info("Failed to scale down decode worker")
# check if we need to scale up workers
# we first check for prefill worker because prefill queueing can also lead
# to high kv load on decode workers
if (
avg_prefill_queue_load > self.args.prefill_queue_scale_up_threshold
and curr_gpu_usage + self.args.prefill_engine_num_gpu
<= self.args.max_gpu_budget
):
logger.info(
f"Average prefill queue load ({avg_prefill_queue_load:.2f}) is above threshold ({self.args.prefill_queue_scale_up_threshold:.2f})"
)
# check prefill queue size trend:
prefill_queue_size_change = (
self.prefill_queue_load[-1] - self.prefill_queue_load[0]
)
predicted_prefill_future_queue_size = (
self.prefill_queue_load[-1]
+ prefill_queue_size_change * NEW_PREFILL_WORKER_QUEUE_BUFFER_PERIOD
)
if (
predicted_prefill_future_queue_size
> self.args.prefill_queue_scale_up_threshold
):
logger.info(
f"Predicted future prefill queue size ({predicted_prefill_future_queue_size:.2f}) is also above threshold ({self.args.prefill_queue_scale_up_threshold:.2f}), scaling up prefill workers"
)
success = await self.connector.add_component("PrefillWorker")
if success:
curr_gpu_usage += self.args.prefill_engine_num_gpu
else:
logger.info("Failed to scale up prefill worker")
else:
logger.info(
f"Predicted future prefill queue size ({predicted_prefill_future_queue_size:.2f}) is below threshold ({self.args.prefill_queue_scale_up_threshold:.2f}), skipping prefill worker scaling"
)
if (
avg_kv_load > self.args.decode_kv_scale_up_threshold
and curr_gpu_usage + self.args.decode_engine_num_gpu
<= self.args.max_gpu_budget
):
logger.info(
f"Average kv load ({avg_kv_load:.2f}) is above threshold ({self.args.decode_kv_scale_up_threshold:.2f}), scaling up decode workers"
)
success = await self.connector.add_component("VllmWorker")
if success:
curr_gpu_usage += self.args.decode_engine_num_gpu
self.decode_worker_remaining_grace_period = (
NEW_DECODE_WORKER_GRACE_PERIOD
)
else:
logger.info("Failed to scale up decode worker")
# no adjustment needed, just log the current metrics
if (
avg_prefill_queue_load > self.args.prefill_queue_scale_down_threshold
and avg_prefill_queue_load < self.args.prefill_queue_scale_up_threshold
):
logger.info(
f"Average prefill queue load ({avg_prefill_queue_load:.2f}) is within threshold, no prefill worker scaling needed"
)
if (
avg_kv_load > self.args.decode_kv_scale_down_threshold
and avg_kv_load < self.args.decode_kv_scale_up_threshold
):
logger.info(
f"Average kv load ({avg_kv_load:.2f}) is within threshold, no decode worker scaling needed"
)
logger.info(f"Engines after adjustment use {curr_gpu_usage} GPUs")
if self.decode_worker_remaining_grace_period > 0:
self.decode_worker_remaining_grace_period -= 1
async def run(self):
"""Main loop for the planner"""
await self.set_metric_aggregator()
await self.reset_adjustment_interval()
while True:
current_time = time.time()
# Collect metrics at each metric pulling interval
if (
len(self.metrics_collection_time) == 0
or current_time - self.metrics_collection_time[-1]
>= self.args.metric_pulling_interval
):
await self.collect_metrics()
# Check if it's time for adjustment
if (
current_time - self.last_adjustment_time
>= self.args.adjustment_interval
):
if not self.args.no_operation:
# blockingly make adjustments to avoid overcompensation
await self.make_adjustments()
await self.reset_adjustment_interval()
# Sleep to avoid busy waiting
await asyncio.sleep(self.args.metric_pulling_interval / 10)
@dynamo_worker()
async def start_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args)
console = Console()
table = Table()
table.add_column("Component", style="cyan")
table.add_column("Endpoint", style="green")
components = await runtime.etcd_client().kv_get_prefix(args.namespace)
for component in components:
try:
data = json.loads(component["value"].decode("utf-8"))
if "component" in data:
name = data["component"]
endpoint = data["endpoint"]
table.add_row(name, endpoint)
except Exception:
# Some entries may not be valid JSON or might be binary data
pass
console.print(table)
await planner.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--namespace",
type=str,
default="dynamo",
help="Namespace planner will look at",
)
parser.add_argument(
"--served-model-name",
type=str,
default="vllm",
help="Model name that is being served (used for prefill queue name)",
)
parser.add_argument(
"--no-operation",
action="store_true",
help="Do not make any adjustments, just observe the metrics",
)
parser.add_argument(
"--log-dir",
type=str,
default=None,
help="Tensorboard logging directory",
)
parser.add_argument(
"--adjustment-interval",
type=int,
default=10,
help="Interval in seconds between scaling adjustments",
)
parser.add_argument(
"--metric-pulling-interval",
type=int,
default=1,
help="Interval in seconds between metric pulls",
)
parser.add_argument(
"--max-gpu-budget",
type=int,
default=8,
help="Maximum number of GPUs to use",
)
parser.add_argument(
"--min-endpoint",
type=int,
default=1,
help="Minimum number of endpoints to keep for prefill/decode workers",
)
parser.add_argument(
"--decode-kv-scale-up-threshold",
type=float,
default=0.9,
help="KV cache utilization threshold to scale up decode workers",
)
parser.add_argument(
"--decode-kv-scale-down-threshold",
type=float,
default=0.5,
help="KV cache utilization threshold to scale down decode workers",
)
parser.add_argument(
"--prefill-queue-scale-up-threshold",
type=float,
default=5,
help="Queue utilization threshold to scale up prefill workers",
)
parser.add_argument(
"--prefill-queue-scale-down-threshold",
type=float,
default=0.2,
help="Queue utilization threshold to scale down prefill workers",
)
parser.add_argument(
"--decode-engine-num-gpu",
type=int,
default=1,
help="Number of GPUs per decode engine",
)
parser.add_argument(
"--prefill-engine-num-gpu",
type=int,
default=1,
help="Number of GPUs per prefill engine",
)
args = parser.parse_args()
asyncio.run(start_planner(args))
@@ -17,7 +17,6 @@
 import asyncio
 import logging
 import os
-import signal
 import sys
 
 from pydantic import BaseModel
@@ -31,6 +30,7 @@ from vllm.inputs.data import TokensPrompt
 from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
 
 from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
+from dynamo.sdk.lib.service import LeaseConfig
 
 logger = logging.getLogger(__name__)
@@ -43,6 +43,7 @@ class RequestType(BaseModel):
     dynamo={
         "enabled": True,
         "namespace": "dynamo",
+        "custom_lease": LeaseConfig(ttl=1),  # 1 second
     },
     resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
     workers=1,
@@ -75,9 +76,6 @@ class PrefillWorker:
         )
         self.engine_args.enable_prefix_caching = False
 
-        signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
-        signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
-
     @async_on_start
     async def async_init(self):
         self._engine_context = build_async_engine_client_from_engine_args(
@@ -91,7 +89,7 @@ class PrefillWorker:
         metadata = self.engine_client.nixl_metadata
         self._metadata_store = NixlMetadataStore("dynamo", runtime)
         await self._metadata_store.put(metadata.engine_id, metadata)
-        task = asyncio.create_task(self.prefill_queue_handler())
+        self.task = asyncio.create_task(self.prefill_queue_handler())
 
         def prefill_queue_handler_cb(fut):
             try:
@@ -101,12 +99,13 @@ class PrefillWorker:
                 logger.error(f"[ERROR] prefill queue handler failed: {e!r}")
                 sys.exit(1)
 
-        task.add_done_callback(prefill_queue_handler_cb)
+        self.task.add_done_callback(prefill_queue_handler_cb)
+        self.lease = dynamo_context["lease"]
         logger.info("PrefillWorker initialized")
 
-    def shutdown_vllm_engine(self, signum, frame):
+    def shutdown_vllm_engine(self):
         """Shutdown the background loop"""
-        logger.info(f"Received signal {signum}, shutting down")
+        logger.info("Shutting down vllm engine")
         loop = asyncio.get_event_loop()
         try:
             self.engine_client.close()
@@ -144,6 +143,20 @@ class PrefillWorker:
                 )
                 async for _ in self.generate(prefill_request):
                     pass
+                is_valid = await self.lease.is_valid()
+                if not is_valid:
+                    logger.info(
+                        "Shutdown requested, checking if engine has any pending prefill sending requests"
+                    )
+                    while True:
+                        if not await self.engine_client.has_unfinished_requests():
+                            break
+                        logger.info(
+                            "Engine has pending prefill sending requests, rechecking in 1 second..."
+                        )
+                        await asyncio.sleep(1)
+                    self.shutdown_vllm_engine()
+                    break
 
     async def generate(self, request: RemotePrefillRequest):
         sampling_params = request.sampling_params
...
@@ -24,12 +24,13 @@ from transformers import AutoTokenizer
 from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
 from utils.logging import check_required_workers
 from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
-from utils.vllm import parse_vllm_args
+from utils.vllm import RouterType, parse_vllm_args
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
 from vllm.outputs import RequestOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
+from dynamo.llm import KvMetricsAggregator
 from dynamo.runtime import EtcdKvCache
 from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
@@ -94,7 +95,7 @@ class Processor(ProcessMixIn):
             .client()
         )
 
-        if self.engine_args.router == "kv":
+        if self.engine_args.router == RouterType.KV:
             router_ns, router_name = Router.dynamo_address()  # type: ignore
             self.router_client = (
                 await runtime.namespace(router_ns)
@@ -105,12 +106,32 @@ class Processor(ProcessMixIn):
 
         await check_required_workers(self.worker_client, self.min_workers)
 
+        kv_listener = runtime.namespace("dynamo").component("VllmWorker")
+        await kv_listener.create_service()
+        self.metrics_aggregator = KvMetricsAggregator(kv_listener)
+
         self.etcd_kv_cache = await EtcdKvCache.create(
             runtime.etcd_client(),
             "/dynamo/processor/",
             {"router": self.engine_args.router},
         )
 
+    async def _get_kv_load(self):
+        metrics = await self.metrics_aggregator.get_metrics()
+        kv_load = {}
+        for endpoint in metrics.endpoints:
+            worker_id = endpoint.worker_id
+            kv_load[worker_id] = getattr(endpoint, "gpu_cache_usage_perc", 0.0)
+        return kv_load
+
+    async def _get_pending_requests(self):
+        metrics = await self.metrics_aggregator.get_metrics()
+        pending_requests = {}
+        for endpoint in metrics.endpoints:
+            worker_id = endpoint.worker_id
+            pending_requests[worker_id] = getattr(endpoint, "num_requests_waiting", 0)
+        return pending_requests
+
     async def _generate(
         self,
         raw_request: Union[CompletionRequest, ChatCompletionRequest],
@@ -125,8 +146,9 @@ class Processor(ProcessMixIn):
             engine_prompt,
             sampling_params,
         ) = await self._parse_raw_request(raw_request)
+        # TODO: queue request at processor when engines are full
         router_mode = (await self.etcd_kv_cache.get("router")).decode()
-        if router_mode == "kv":
+        if router_mode == RouterType.KV:
            router_generator = await self.router_client.generate(
                Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
            )
@@ -157,7 +179,7 @@ class Processor(ProcessMixIn):
                ).model_dump_json(),
                int(worker_id),
            )
-        elif router_mode == "random":
+        elif router_mode == RouterType.RANDOM:
            engine_generator = await self.worker_client.generate(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
@@ -165,7 +187,7 @@ class Processor(ProcessMixIn):
                    request_id=request_id,
                ).model_dump_json()
            )
-        elif router_mode == "round-robin":
+        elif router_mode == RouterType.ROUND_ROBIN:
            engine_generator = await self.worker_client.round_robin(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
@@ -173,7 +195,32 @@ class Processor(ProcessMixIn):
                    request_id=request_id,
                ).model_dump_json()
            )
+        elif router_mode == RouterType.KV_LOAD:
+            # route to worker with least kv load
+            # TODO: move the router to a separate file and clean up processor.py
+            try:
+                kv_load = await self._get_kv_load()
+                best_worker_id = min(kv_load, key=kv_load.get)
+                logger.info(f"Routing to worker {best_worker_id} (kv load: {kv_load})")
+                engine_generator = await self.worker_client.direct(
+                    vLLMGenerateRequest(
+                        engine_prompt=engine_prompt,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                    ).model_dump_json(),
+                    int(best_worker_id),
+                )
+            except Exception as e:
+                logger.info(
+                    f"Error finding worker with least kv load: {e}, fallback to random"
+                )
+                engine_generator = await self.worker_client.generate(
+                    vLLMGenerateRequest(
+                        engine_prompt=engine_prompt,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                    ).model_dump_json()
+                )
 
         output = self._generate_responses(engine_generator, request_type)
         async for response in await self._stream_response(
...
@@ -24,7 +24,7 @@ from components.prefill_worker import PrefillWorker
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import MyRequestOutput, vLLMGenerateRequest
-from utils.vllm import parse_vllm_args
+from utils.vllm import RouterType, parse_vllm_args
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args,
 )
@@ -82,7 +82,7 @@ class VllmWorker:
             logger.info("Pipeline parallel size is not supported yet, setting to 1")
             self.engine_args.pipeline_parallel_size = 1
 
-        if self.engine_args.router == "kv":
+        if self.engine_args.router == RouterType.KV:
             if not self.engine_args.enable_prefix_caching:
                 logger.info(
                     "When using KV router, prefix caching must be enabled, setting to True"
@@ -107,8 +107,6 @@ class VllmWorker:
             self.engine_client = await self._engine_context.__aenter__()
         else:
             raise RuntimeError("Failed to initialize engine client")
-        if self.engine_args.router == "kv":
-            assert self.engine_client is not None, "engine_client was not initialized"
         self.engine_client.set_metrics_publisher(self.metrics_publisher)
         # Initially send dummy metrics to kick start,
         # vLLM will not update stat until forward pass is triggered
@@ -144,7 +142,6 @@ class VllmWorker:
         else:
             self.disaggregated_router = None
 
-        self.lease = dynamo_context.get("lease")
         logger.info("VllmWorker has been initialized")
 
     def shutdown_vllm_engine(self, signum, frame):
@@ -161,13 +158,12 @@ class VllmWorker:
     async def create_metrics_publisher_endpoint(self):
         component = dynamo_context["component"]
-        if self.lease is None:
+        lease = dynamo_context["lease"]
+        if lease is None:
             logger.info("Creating metrics publisher endpoint with primary lease")
         else:
-            logger.info(
-                f"Creating metrics publisher endpoint with lease: {self.lease.id()}"
-            )
-        await self.metrics_publisher.create_endpoint(component, self.lease)
+            logger.info(f"Creating metrics publisher endpoint with lease: {lease}")
+        await self.metrics_publisher.create_endpoint(component, lease)
 
     def get_remote_prefill_request_callback(self):
         # TODO: integrate prefill_queue to dynamo endpoint
...
@@ -92,7 +92,11 @@ class NATSQueue:
             await self._js.stream_info(self._stream_name)
         except NotFoundError:
             await self._js.add_stream(
-                name=self._stream_name, subjects=[self._subject]
+                name=self._stream_name,
+                subjects=[self._subject],
+                # TODO: make these configurable and add guide to set these values
+                max_bytes=1073741824,  # 1GB total storage limit
+                max_msgs=1000000,  # 1 million messages limit
             )
         # Create persistent subscriber
         self._subscriber = await self._js.pull_subscribe(
...
@@ -20,6 +20,13 @@ from vllm.utils import FlexibleArgumentParser
 from dynamo.sdk.lib.config import ServiceConfig
 
 
+class RouterType:
+    RANDOM = "random"
+    ROUND_ROBIN = "round-robin"
+    KV = "kv"
+    KV_LOAD = "kv-load"
+
+
 def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
     config = ServiceConfig.get_instance()
     vllm_args = config.as_args(service_name, prefix=prefix)
@@ -27,8 +34,13 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
     parser.add_argument(
         "--router",
         type=str,
-        choices=["random", "round-robin", "kv"],
-        default="round-robin",
+        choices=[
+            RouterType.RANDOM,
+            RouterType.ROUND_ROBIN,
+            RouterType.KV,
+            RouterType.KV_LOAD,
+        ],
+        default=RouterType.RANDOM,
         help="Router type to use for scheduling requests to workers",
     )
     parser.add_argument(
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
parser = argparse.ArgumentParser()
parser.add_argument("lease_id", type=int, help="Lease ID to revoke")
args = parser.parse_args()
await init(runtime, args.lease_id)
async def init(runtime: DistributedRuntime, lease_id: int):
client = runtime.etcd_client()
await client.revoke_lease(lease_id)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
@@ -164,6 +164,14 @@ impl PyLease {
     fn revoke(&self) {
         self.inner.revoke();
     }
+
+    fn is_valid<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
+        let lease = self.inner.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let is_valid = lease.is_valid().await.map_err(to_pyerr)?;
+            Ok(is_valid)
+        })
+    }
 }
 
 #[pymethods]
@@ -530,6 +538,14 @@ impl EtcdClient {
             Ok(py_list)
         })
     }
+
+    fn revoke_lease<'p>(&self, py: Python<'p>, lease_id: i64) -> PyResult<Bound<'p, PyAny>> {
+        let client = self.inner.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            client.revoke_lease(lease_id).await.map_err(to_pyerr)?;
+            Ok(())
+        })
+    }
 }
 
 #[pymethods]
...
@@ -68,6 +68,12 @@ class PyLease:
         """
         ...
 
+    def is_valid(self) -> bool:
+        """
+        Check if the lease is still valid (not revoked)
+        """
+        ...
+
 class EtcdClient:
     """
     Etcd is used for discovery in the DistributedRuntime
...
@@ -74,6 +74,13 @@ impl Lease {
     pub fn revoke(&self) {
         self.cancel_token.cancel();
     }
+
+    /// Check if the lease is still valid (not revoked)
+    pub async fn is_valid(&self) -> Result<bool> {
+        // A lease is valid if its cancellation token has not been triggered
+        // We can use try_cancelled which returns immediately with a boolean
+        Ok(!self.cancel_token.is_cancelled())
+    }
 }
 
 impl Client {
@@ -150,6 +157,15 @@ impl Client {
             .await?
     }
 
+    // Revoke an etcd lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
+    pub async fn revoke_lease(&self, lease_id: i64) -> Result<()> {
+        let lease_client = self.client.lease_client();
+        self.runtime
+            .secondary()
+            .spawn(revoke_lease(lease_client, lease_id))
+            .await?
+    }
+
     pub async fn kv_create(
         &self,
         key: String,
...
@@ -44,6 +44,17 @@ pub async fn create_lease(
     })
 }
 
+/// Revoke a lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
+pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: i64) -> Result<()> {
+    match lease_client.revoke(lease_id).await {
+        Ok(_) => Ok(()),
+        Err(e) => {
+            tracing::warn!("failed to revoke lease: {:?}", e);
+            Err(e.into())
+        }
+    }
+}
+
 /// Task to keep leases alive.
 ///
 /// If this task returns an error, the cancellation token will be invoked on the runtime.
...
@@ -32,6 +32,7 @@ dependencies = [
     "ai-dynamo-runtime==0.2.0",
     "distro",
     "typer",
+    "circus>=0.17.0",
 ]
 
 classifiers = [
@@ -76,7 +77,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [tool.hatch.build.targets.wheel]
-packages = ["deploy/dynamo/sdk/src/dynamo"]
+packages = ["deploy/dynamo/sdk/src/dynamo", "components/planner/src/dynamo"]
 
 # This section is for including the binaries in the wheel package
 # but doesn't make them executable scripts in the venv bin directory
...