Unverified Commit d23d48ba authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: Deploy SLA planner to Kubernetes (#2135)


Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarhongkuan <hongkuanz@nvidia.com>
Co-authored-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
parent ca0035fb
../../docs/architecture/pre_deployment_profiling.md
\ No newline at end of file
...@@ -589,9 +589,9 @@ if __name__ == "__main__": ...@@ -589,9 +589,9 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
default="vllm_v1", default="vllm",
choices=["vllm_v1"], choices=["vllm"],
help="backend type, currently support [vllm_v1]", help="backend type, currently support [vllm]",
) )
parser.add_argument( parser.add_argument(
"--config", "--config",
......
...@@ -80,7 +80,7 @@ class VllmV1ConfigModifier: ...@@ -80,7 +80,7 @@ class VllmV1ConfigModifier:
config = deepcopy(config) config = deepcopy(config)
# set metadata name # set metadata name
config["metadata"]["name"] = "vllm-v1-agg" config["metadata"]["name"] = "vllm-agg"
# disable planner # disable planner
if "Planner" in config["spec"]["services"]: if "Planner" in config["spec"]["services"]:
...@@ -89,16 +89,16 @@ class VllmV1ConfigModifier: ...@@ -89,16 +89,16 @@ class VllmV1ConfigModifier:
if target == "prefill": if target == "prefill":
# convert prefill worker into decode worker # convert prefill worker into decode worker
config["spec"]["services"][ config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker WORKER_COMPONENT_NAMES["vllm"].decode_worker
] = config["spec"]["services"][ ] = config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].prefill_worker WORKER_COMPONENT_NAMES["vllm"].prefill_worker
] ]
del config["spec"]["services"][ del config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].prefill_worker WORKER_COMPONENT_NAMES["vllm"].prefill_worker
] ]
args = config["spec"]["services"][ args = config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker WORKER_COMPONENT_NAMES["vllm"].decode_worker
]["extraPodSpec"]["mainContainer"]["args"] ]["extraPodSpec"]["mainContainer"]["args"]
args = break_arguments(args) args = break_arguments(args)
...@@ -112,18 +112,18 @@ class VllmV1ConfigModifier: ...@@ -112,18 +112,18 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" not in args: if "--no-enable-prefix-caching" not in args:
args = append_argument(args, "--no-enable-prefix-caching") args = append_argument(args, "--no-enable-prefix-caching")
config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker][ config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
"extraPodSpec" "extraPodSpec"
]["mainContainer"]["args"] = join_arguments(args) ]["mainContainer"]["args"] = join_arguments(args)
elif target == "decode": elif target == "decode":
# delete prefill worker # delete prefill worker
del config["spec"]["services"][ del config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].prefill_worker WORKER_COMPONENT_NAMES["vllm"].prefill_worker
] ]
args = config["spec"]["services"][ args = config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker WORKER_COMPONENT_NAMES["vllm"].decode_worker
]["extraPodSpec"]["mainContainer"]["args"] ]["extraPodSpec"]["mainContainer"]["args"]
args = break_arguments(args) args = break_arguments(args)
...@@ -134,13 +134,13 @@ class VllmV1ConfigModifier: ...@@ -134,13 +134,13 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" in args: if "--no-enable-prefix-caching" in args:
args.remove("--no-enable-prefix-caching") args.remove("--no-enable-prefix-caching")
config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker][ config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
"extraPodSpec" "extraPodSpec"
]["mainContainer"]["args"] = join_arguments(args) ]["mainContainer"]["args"] = join_arguments(args)
# set num workers to 1 # set num workers to 1
decode_worker_config = config["spec"]["services"][ decode_worker_config = config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker WORKER_COMPONENT_NAMES["vllm"].decode_worker
] ]
decode_worker_config["replicas"] = 1 decode_worker_config["replicas"] = 1
...@@ -150,16 +150,16 @@ class VllmV1ConfigModifier: ...@@ -150,16 +150,16 @@ class VllmV1ConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int): def set_config_tp_size(cls, config: dict, tp_size: int):
config = deepcopy(config) config = deepcopy(config)
config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker][ config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
"resources" "resources"
]["requests"]["gpu"] = str(tp_size) ]["requests"]["gpu"] = str(tp_size)
config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker][ config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
"resources" "resources"
]["limits"]["gpu"] = str(tp_size) ]["limits"]["gpu"] = str(tp_size)
args = config["spec"]["services"][ args = config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker "extraPodSpec"
]["extraPodSpec"]["mainContainer"]["args"] ]["mainContainer"]["args"]
args = break_arguments(args) args = break_arguments(args)
...@@ -169,7 +169,7 @@ class VllmV1ConfigModifier: ...@@ -169,7 +169,7 @@ class VllmV1ConfigModifier:
except ValueError: except ValueError:
args = append_argument(args, ["--tensor-parallel-size", str(tp_size)]) args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker][ config["spec"]["services"][WORKER_COMPONENT_NAMES["vllm"].decode_worker][
"extraPodSpec" "extraPodSpec"
]["mainContainer"]["args"] = join_arguments(args) ]["mainContainer"]["args"] = join_arguments(args)
...@@ -177,7 +177,7 @@ class VllmV1ConfigModifier: ...@@ -177,7 +177,7 @@ class VllmV1ConfigModifier:
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
worker_name = WORKER_COMPONENT_NAMES["vllm_v1"].decode_worker worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker
args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][ args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][
"args" "args"
] ]
...@@ -232,5 +232,5 @@ class VllmV1ConfigModifier: ...@@ -232,5 +232,5 @@ class VllmV1ConfigModifier:
CONFIG_MODIFIERS = { CONFIG_MODIFIERS = {
"vllm_v1": VllmV1ConfigModifier, "vllm": VllmV1ConfigModifier,
} }
...@@ -17,9 +17,9 @@ import argparse ...@@ -17,9 +17,9 @@ import argparse
import asyncio import asyncio
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Any, Dict, List, Optional, Union
import aiofiles import aiofiles # type: ignore[import-untyped]
import httpx # added for HTTP requests import httpx # added for HTTP requests
import kubernetes_asyncio as kubernetes import kubernetes_asyncio as kubernetes
import yaml import yaml
...@@ -62,9 +62,9 @@ class DynamoDeploymentClient: ...@@ -62,9 +62,9 @@ class DynamoDeploymentClient:
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.model_name = model_name self.model_name = model_name
self.service_name = service_name or f"{deployment_name}-frontend" self.service_name = service_name or f"{deployment_name}-frontend"
self.components: list[str] = [] # Will store component names from CR self.components: List[str] = [] # Will store component names from CR
self.deployment_spec: Optional[ self.deployment_spec: Optional[
dict Dict[str, Any]
] = None # Will store the full deployment spec ] = None # Will store the full deployment spec
self.base_log_dir = Path(base_log_dir) if base_log_dir else Path("logs") self.base_log_dir = Path(base_log_dir) if base_log_dir else Path("logs")
self.frontend_port = frontend_port self.frontend_port = frontend_port
......
...@@ -112,6 +112,7 @@ For Kubernetes deployment, YAML manifests are provided in the `deploy/` director ...@@ -112,6 +112,7 @@ For Kubernetes deployment, YAML manifests are provided in the `deploy/` director
- `agg_router.yaml` - Aggregated serving with KV routing - `agg_router.yaml` - Aggregated serving with KV routing
- `disagg.yaml` - Disaggregated serving - `disagg.yaml` - Disaggregated serving
- `disagg_router.yaml` - Disaggregated serving with KV routing - `disagg_router.yaml` - Disaggregated serving with KV routing
- `disagg_planner.yaml` - Disaggregated serving with [SLA Planner](../../../docs/architecture/sla_planner.md). See [SLA Planner Deployment Guide](../../../docs/guides/dynamo_deploy/sla_planner_deployment.md) for more details.
#### Prerequisites #### Prerequisites
...@@ -124,6 +125,8 @@ For Kubernetes deployment, YAML manifests are provided in the `deploy/` director ...@@ -124,6 +125,8 @@ For Kubernetes deployment, YAML manifests are provided in the `deploy/` director
# Update the image references in the YAML files # Update the image references in the YAML files
``` ```
- **Pre-Deployment Profiling (if Using SLA Planner)**: Follow the [pre-deployment profiling guide](../../../docs/architecture/pre_deployment_profiling.md) to run pre-deployment profiling. The results will be saved to the `profiling-pvc` PVC and queried by the SLA Planner.
- **Port Forwarding**: After deployment, forward the frontend service to access the API: - **Port Forwarding**: After deployment, forward the frontend service to access the API:
```bash ```bash
kubectl port-forward deployment/vllm-v1-disagg-frontend-<pod-uuid-info> 8080:8000 kubectl port-forward deployment/vllm-v1-disagg-frontend-<pod-uuid-info> 8080:8000
......
...@@ -6,6 +6,13 @@ kind: DynamoGraphDeployment ...@@ -6,6 +6,13 @@ kind: DynamoGraphDeployment
metadata: metadata:
name: vllm-disagg-planner name: vllm-disagg-planner
spec: spec:
envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:9090"]}]},{"job_name":"frontend","static_configs":[{"targets":["vllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_PORT
value: "8000"
- name: DYNAMO_NAMESPACE
value: "vllm-disagg-planner"
services: services:
Frontend: Frontend:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
...@@ -31,25 +38,114 @@ spec: ...@@ -31,25 +38,114 @@ spec:
failureThreshold: 10 failureThreshold: 10
resources: resources:
requests: requests:
cpu: "1" cpu: "32"
memory: "10Gi"
limits:
cpu: "32"
memory: "10Gi"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-253.17
workingDir: /workspace/components/backends/vllm
command:
- /bin/sh
- -c
args:
- "python3 -m dynamo.frontend --http-port 8000"
Planner:
dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret
componentType: planner
replicas: 1
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 60
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
resources:
requests:
cpu: "2"
memory: "2Gi" memory: "2Gi"
limits: limits:
cpu: "1" cpu: "2"
memory: "2Gi" memory: "2Gi"
pvc:
create: false
name: profiling-pvc # Must be pre-created before deployment and SLA profiler must have been run
mountPoint: /workspace/profiling_results
extraPodSpec: extraPodSpec:
mainContainer: mainContainer:
image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-233.17 image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-253.17
workingDir: /workspace/components/planner/src/dynamo/planner
args:
- python
- -m
- planner_sla
- --environment=kubernetes
- --backend=vllm
- --adjustment-interval=60
- --profile-results-dir=/workspace/profiling_results
Prometheus:
dynamoNamespace: vllm-disagg-planner
componentType: main
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
resources:
requests:
cpu: "2"
memory: "2Gi"
limits:
cpu: "2"
memory: "2Gi"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-253.17
workingDir: /workspace/components/backends/vllm workingDir: /workspace/components/backends/vllm
command: command:
- /bin/sh - /bin/sh
- -c - -c
args: args:
- "python3 -m dynamo.frontend --http-port 8000" - "python3 -m dynamo.planner.prometheus"
VllmDecodeWorker: backend:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
replicas: 1 replicas: 2
livenessProbe: livenessProbe:
httpGet: httpGet:
path: /live path: /live
...@@ -66,12 +162,12 @@ spec: ...@@ -66,12 +162,12 @@ spec:
failureThreshold: 60 failureThreshold: 60
resources: resources:
requests: requests:
cpu: "10" cpu: "8"
memory: "20Gi" memory: "16Gi"
gpu: "1" gpu: "1"
limits: limits:
cpu: "10" cpu: "8"
memory: "20Gi" memory: "16Gi"
gpu: "1" gpu: "1"
envs: envs:
- name: DYN_SYSTEM_ENABLED - name: DYN_SYSTEM_ENABLED
...@@ -88,18 +184,18 @@ spec: ...@@ -88,18 +184,18 @@ spec:
port: 9090 port: 9090
periodSeconds: 10 periodSeconds: 10
failureThreshold: 60 failureThreshold: 60
image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-233.17 image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-253.17
workingDir: /workspace/components/backends/vllm workingDir: /workspace/components/backends/vllm
command: command:
- /bin/sh - /bin/sh
- -c - -c
args: args:
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log" - "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log"
VllmPrefillWorker: prefill:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
replicas: 1 replicas: 2
livenessProbe: livenessProbe:
httpGet: httpGet:
path: /health path: /health
...@@ -116,12 +212,12 @@ spec: ...@@ -116,12 +212,12 @@ spec:
failureThreshold: 60 failureThreshold: 60
resources: resources:
requests: requests:
cpu: "10" cpu: "8"
memory: "20Gi" memory: "16Gi"
gpu: "1" gpu: "1"
limits: limits:
cpu: "10" cpu: "8"
memory: "20Gi" memory: "16Gi"
gpu: "1" gpu: "1"
envs: envs:
- name: DYN_SYSTEM_ENABLED - name: DYN_SYSTEM_ENABLED
...@@ -138,10 +234,10 @@ spec: ...@@ -138,10 +234,10 @@ spec:
port: 9090 port: 9090
periodSeconds: 10 periodSeconds: 10
failureThreshold: 60 failureThreshold: 60
image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-233.17 image: nvcr.io/nvidian/nim-llm-dev/vllm-runtime:dep-253.17
workingDir: /workspace/components/backends/vllm workingDir: /workspace/components/backends/vllm
command: command:
- /bin/sh - /bin/sh
- -c - -c
args: args:
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log - python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log
...@@ -55,7 +55,7 @@ def parse_args() -> Config: ...@@ -55,7 +55,7 @@ def parse_args() -> Config:
parser.add_argument( parser.add_argument(
"--is-prefill-worker", "--is-prefill-worker",
action="store_true", action="store_true",
help="Enable prefill functionality for this worker. Currently overwrites the --endpoint to be a specially chosen dyn://dynamo.prefill.generate", help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate",
) )
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -79,8 +79,13 @@ def parse_args() -> Config: ...@@ -79,8 +79,13 @@ def parse_args() -> Config:
# This becomes an `Option` on the Rust side # This becomes an `Option` on the Rust side
config.served_model_name = None config.served_model_name = None
namespace = os.environ.get("DYNAMO_NAMESPACE", "dynamo")
if args.is_prefill_worker: if args.is_prefill_worker:
args.endpoint = "dyn://dynamo.prefill.generate" args.endpoint = f"dyn://{namespace}.prefill.generate"
else:
# For decode workers, also use the provided namespace instead of hardcoded "dynamo"
args.endpoint = f"dyn://{namespace}.backend.generate"
endpoint_str = args.endpoint.replace("dyn://", "", 1) endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".") endpoint_parts = endpoint_str.split(".")
...@@ -127,6 +132,14 @@ async def allocate_and_reserve_port( ...@@ -127,6 +132,14 @@ async def allocate_and_reserve_port(
""" """
node_name = socket.gethostname() node_name = socket.gethostname()
try:
node_ip = socket.gethostbyname(node_name)
except socket.gaierror:
# If hostname cannot be resolved, fall back to localhost
logger.warning(
f"Hostname '{node_name}' cannot be resolved, falling back to '127.0.0.1'"
)
node_ip = "127.0.0.1"
for attempt in range(1, max_attempts + 1): for attempt in range(1, max_attempts + 1):
# Hold socket open just long enough to reserve in ETCD # Hold socket open just long enough to reserve in ETCD
...@@ -136,7 +149,7 @@ async def allocate_and_reserve_port( ...@@ -136,7 +149,7 @@ async def allocate_and_reserve_port(
port = sock.getsockname()[1] port = sock.getsockname()[1]
# Reserve in ETCD while holding the socket # Reserve in ETCD while holding the socket
key = f"dyn://{namespace}/ports/{node_name}/{port}" key = f"dyn://{namespace}/ports/{node_ip}/{port}"
value = { value = {
"worker_id": worker_id, "worker_id": worker_id,
"reason": reason, "reason": reason,
...@@ -242,23 +255,41 @@ def overwrite_args(config): ...@@ -242,23 +255,41 @@ def overwrite_args(config):
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.") raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None): def get_host_ip() -> str:
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. """Get the IP address of the host.
This sets the port number for the side channel. This is needed for the side channel to work in multi-node deployments.
""" """
if hostname is None: try:
hostname = socket.gethostname() host_name = socket.gethostname()
# Test if hostname is usable by attempting to bind to it except socket.error as e:
logger.warning(f"Failed to get hostname: {e}, falling back to '127.0.0.1'")
return "127.0.0.1"
else:
try: try:
# Get the IP address of the hostname - this is needed for the side channel to work in multi-node deployments
host_ip = socket.gethostbyname(host_name)
# Test if the IP is actually usable by binding to it
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0)) test_socket.bind((host_ip, 0))
except (socket.error, socket.gaierror): return host_ip
# If hostname is not usable, fall back to localhost except socket.gaierror as e:
logger.warning( logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'" f"Hostname '{host_name}' cannot be resolved: {e}, falling back to '127.0.0.1'"
) )
hostname = "127.0.0.1" return "127.0.0.1"
except socket.error as e:
# If hostname is not usable for binding, fall back to localhost
logger.warning(
f"Hostname '{host_name}' is not usable for binding: {e}, falling back to '127.0.0.1'"
)
return "127.0.0.1"
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname def set_side_channel_host_and_port(config: Config):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
host_ip = get_host_ip()
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port) os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}") logger.debug(f"Set NIXL side channel to {host_ip}:{config.side_channel_port}")
...@@ -15,113 +15,4 @@ See the License for the specific language governing permissions and ...@@ -15,113 +15,4 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
# Planner Please refer to [planner docs](../../docs/architecture/planner_intro.rst) for planner documentation.
The planner is a component that monitors the state of the system and makes adjustments to the number of workers to ensure that the system is running efficiently. It can dynamically scale prefill/decode workers up and down based on a variety of KV metrics. You can find documentation and benchmarking examples in the [planner docs](../../docs/guides/planner.md).
## Usage
After you've deployed a dynamo graph, you can start the planner with the following command:
```bash
PYTHONPATH=/workspace/examples/llm python components/planner.py --namespace <namespace>
```
## Backends
1. `local` - uses circus to start/stop worker subprocesses
2. `kubernetes` - uses the kubernetes API to adjust replicas of the DynamoGraphDeployment resource, which automatically scales the corresponding worker pods up or down
## Local Backend (LocalPlanner)
The LocalPlanner is built on top of circus, which is what we use to manage component subprocesses when running with the frontend and workers. LocalPlanner allows the planner component to scale workers up and down based on system metrics.
**Current limitations**
1. Single node only
2. Workers must be using only a single GPU
3. Your initial deployment must be replicas=1 for both prefill and decode
We are working on addressing these as fast as possible.
### Under the Hood
Circus has a concept of an arbiter and a watcher:
- **Arbiter**: The supervisor process that manages all watchers
- **Watcher**: A process that encodes environment variables, command, name, and other information needed to run a component
When a service is started, each worker process is spun up as a watcher. For example, when starting a VllmWorker, a watcher is created that looks like:
```json
{
"dynamo_VllmWorker": {
"watcher_name": "dynamo_VllmWorker",
"cmd": "/opt/dynamo/venv/bin/python3 -m dynamo.sdk.cli.serve_dynamo graphs.agg_router:Frontend --service-name VllmWorker --worker-id $(CIRCUS.WID) --worker-env [{\"CUDA_VISIBLE_DEVICES\": \"0\"}]",
"resources": {
"allocated_gpus": [
0
]
},
"lease": 7587886183172559418
}
}
```
The arbiter exposes an endpoint allowing messages to add/remove/change watchers. The LocalPlanner leverages this functionality to dynamically adjust worker counts.
### Implementation
The planner architecture is designed to be simple and extensible:
- An abstract class supports basic add/remove component operations
- This is implemented in `local_connector.py`
- Circus interaction logic is in `circusd.py`, which reads the statefile, connects to the endpoint, and provides add/remove functionality
- Planner starts an instance of `LocalConnector` and uses it to modify the deployment topology
### Statefile
The statefile maintains the current state of all running workers and is used by the LocalPlanner to track and modify the deployment. It's stored at `~/.dynamo/state/{namespace}.json` (or in the directory specified by `DYN_LOCAL_STATE_DIR`). The statefile is automatically created when you run the frontend with workers and is cleaned up when the arbiter terminates. Each worker is identified as `{namespace}_{component_name}` with an optional numeric suffix for additional instances.
#### Example: Adding and Removing Workers
Starting with a single decode worker:
```json
{
"dynamo_VllmWorker": {..., "resources":{...}}
}
```
After adding a worker:
```json
{
"dynamo_VllmWorker": {..., "resources":{...}},
"dynamo_VllmWorker_1": {..., "resources":{...}}
}
```
After removing a worker (removes the highest suffix):
```json
{
"dynamo_VllmWorker": {..., "resources":{...}}
}
```
If scaled to zero, the initial entry is kept without resources to maintain configuration information:
```json
{
"dynamo_VllmWorker": {...}
}
```
### Looking forward
- Support for a multinode LocalPlanner
- Storing the statefile (and initial configurations) in ETCD using the the new `EtcdKvCache`.
### Testing
For manual testing, you can use the controller_test.py file to add/remove components after you've run a serve command on a Dynamo pipeline where the planner is linked.
## Kubernetes Backend
The Kubernetes backend works by updating the replicas count of the DynamoGraphDeployment custom resource. When the planner determines that workers need to be scaled up or down based on workload metrics, it uses the Kubernetes API to patch the DynamoGraphDeployment resource specification, changing the replicas count for the appropriate worker component. The Kubernetes operator then reconciles this change by creating or terminating the necessary pods. This provides a seamless autoscaling experience in Kubernetes environments without requiring manual intervention.
The Kubernetes backend will automatically be used by Planner when your pipeline is deployed using a DynamoGraphDeployment CR. By default, the planner will run in no-op mode, which means it will monitor metrics but not take scaling actions. To enable actual scaling, you should also specify `--Planner.no-operation=false`.
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
__all__ = [ __all__ = [
"CircusController", "CircusController",
"LocalConnector",
"PlannerConnector", "PlannerConnector",
"KubernetesConnector", "KubernetesConnector",
"LoadPlannerDefaults", "LoadPlannerDefaults",
...@@ -26,5 +25,4 @@ __all__ = [ ...@@ -26,5 +25,4 @@ __all__ = [
from dynamo.planner.circusd import CircusController from dynamo.planner.circusd import CircusController
from dynamo.planner.defaults import LoadPlannerDefaults, SLAPlannerDefaults from dynamo.planner.defaults import LoadPlannerDefaults, SLAPlannerDefaults
from dynamo.planner.kubernetes_connector import KubernetesConnector from dynamo.planner.kubernetes_connector import KubernetesConnector
from dynamo.planner.local_connector import LocalConnector
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
...@@ -13,12 +13,21 @@ ...@@ -13,12 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os
from dynamo.planner.kube import get_current_k8s_namespace
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
# Source of truth for planner defaults # Source of truth for planner defaults
class BasePlannerDefaults: class BasePlannerDefaults:
namespace = "dynamo" namespace = "dynamo"
environment = "local" environment = "kubernetes"
backend = "vllm_v0" backend = "vllm"
no_operation = False no_operation = False
log_dir = None log_dir = None
adjustment_interval = 180 # in seconds adjustment_interval = 180 # in seconds
...@@ -36,8 +45,25 @@ class LoadPlannerDefaults(BasePlannerDefaults): ...@@ -36,8 +45,25 @@ class LoadPlannerDefaults(BasePlannerDefaults):
prefill_queue_scale_down_threshold = 0.2 prefill_queue_scale_down_threshold = 0.2
def _get_default_prometheus_endpoint(port: str, namespace: str):
"""Compute default prometheus endpoint using environment variables and Kubernetes service discovery"""
k8s_namespace = get_current_k8s_namespace()
if k8s_namespace and k8s_namespace != "default":
prometheus_service = f"{namespace}-prometheus"
return f"http://{prometheus_service}.{k8s_namespace}.svc.cluster.local:{port}"
else:
logger.warning(
f"Cannot determine Prometheus endpoint. Running in namespace '{k8s_namespace}'. "
"Ensure the planner is deployed in a Kubernetes cluster with proper namespace configuration."
)
return f"{namespace}-prometheus"
class SLAPlannerDefaults(BasePlannerDefaults): class SLAPlannerDefaults(BasePlannerDefaults):
prometheus_endpoint = "http://localhost:9090" port = os.environ.get("DYNAMO_PORT", "8000")
namespace = os.environ.get("DYNAMO_NAMESPACE", "vllm-disagg-planner")
prometheus_endpoint = _get_default_prometheus_endpoint(port, namespace)
profile_results_dir = "profiling_results" profile_results_dir = "profiling_results"
isl = 3000 # in number of tokens isl = 3000 # in number of tokens
osl = 150 # in number of tokens osl = 150 # in number of tokens
...@@ -47,21 +73,13 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -47,21 +73,13 @@ class SLAPlannerDefaults(BasePlannerDefaults):
load_prediction_window_size = 50 # predict load using how many recent load samples load_prediction_window_size = 50 # predict load using how many recent load samples
class VllmV0ComponentName: class VllmComponentName:
prefill_worker = "PrefillWorker" prefill_worker = "prefill"
prefill_worker_endpoint = "mock"
decode_worker = "VllmWorker"
decode_worker_endpoint = "generate"
class VllmV1ComponentName:
prefill_worker = "VllmPrefillWorker"
prefill_worker_endpoint = "generate" prefill_worker_endpoint = "generate"
decode_worker = "VllmDecodeWorker" decode_worker = "backend"
decode_worker_endpoint = "generate" decode_worker_endpoint = "generate"
WORKER_COMPONENT_NAMES = { WORKER_COMPONENT_NAMES = {
"vllm_v0": VllmV0ComponentName, "vllm": VllmComponentName,
"vllm_v1": VllmV1ComponentName,
} }
...@@ -17,26 +17,41 @@ import asyncio ...@@ -17,26 +17,41 @@ import asyncio
from typing import Optional from typing import Optional
from kubernetes import client, config from kubernetes import client, config
from kubernetes.config.config_exception import ConfigException
def get_current_k8s_namespace() -> str:
"""Get the current namespace if running inside a k8s cluster"""
try:
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") as f:
return f.read().strip()
except FileNotFoundError:
# Fallback to 'default' if not running in k8s
return "default"
class KubernetesAPI: class KubernetesAPI:
def __init__(self): def __init__(self, k8s_namespace: Optional[str] = None):
# Load kubernetes configuration # Load kubernetes configuration
config.load_incluster_config() # for in-cluster deployment try:
config.load_incluster_config() # for in-cluster deployment
except ConfigException:
config.load_kube_config() # for out-of-cluster deployment
self.custom_api = client.CustomObjectsApi() self.custom_api = client.CustomObjectsApi()
self.current_namespace = self._get_current_namespace() self.current_namespace = k8s_namespace or get_current_k8s_namespace()
def _get_current_namespace(self) -> str: def _get_graph_deployment_from_name(
"""Get the current namespace if running inside a k8s cluster""" self, graph_deployment_name: str
try: ) -> Optional[dict]:
with open( """Get the graph deployment from the dynamo graph deployment name"""
"/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r" return self.custom_api.get_namespaced_custom_object(
) as f: group="nvidia.com",
return f.read().strip() version="v1alpha1",
except FileNotFoundError: namespace=self.current_namespace,
# Fallback to 'default' if not running in k8s plural="dynamographdeployments",
return "default" name=graph_deployment_name,
)
async def get_graph_deployment( async def get_graph_deployment(
self, component_name: str, dynamo_namespace: str self, component_name: str, dynamo_namespace: str
...@@ -98,12 +113,8 @@ class KubernetesAPI: ...@@ -98,12 +113,8 @@ class KubernetesAPI:
if not graph_deployment_name: if not graph_deployment_name:
return None return None
graph_deployment = self.custom_api.get_namespaced_custom_object( graph_deployment = self._get_graph_deployment_from_name(
group="nvidia.com", graph_deployment_name
version="v1alpha1",
namespace=self.current_namespace,
plural="dynamographdeployments",
name=graph_deployment_name,
) )
return graph_deployment return graph_deployment
...@@ -127,19 +138,36 @@ class KubernetesAPI: ...@@ -127,19 +138,36 @@ class KubernetesAPI:
body=patch, body=patch,
) )
async def is_deployment_ready(self, graph_deployment_name: str) -> bool:
"""Check if a graph deployment is ready"""
graph_deployment = self._get_graph_deployment_from_name(graph_deployment_name)
if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found")
conditions = graph_deployment.get("status", {}).get("conditions", [])
ready_condition = next(
(c for c in conditions if c.get("type") == "Ready"), None
)
return ready_condition is not None and ready_condition.get("status") == "True"
async def wait_for_graph_deployment_ready( async def wait_for_graph_deployment_ready(
self, self,
graph_deployment_name: str, graph_deployment_name: str,
max_attempts: int = 60, # default: 10 minutes total max_attempts: int = 180, # default: 30 minutes total
delay_seconds: int = 10, # default: check every 10 seconds delay_seconds: int = 10, # default: check every 10 seconds
) -> None: ) -> None:
"""Wait for a graph deployment to be ready""" """Wait for a graph deployment to be ready"""
for attempt in range(max_attempts): for attempt in range(max_attempts):
await asyncio.sleep(delay_seconds) await asyncio.sleep(delay_seconds)
graph_deployment = await self.get_graph_deployment(
graph_deployment_name, self.current_namespace graph_deployment = self._get_graph_deployment_from_name(
graph_deployment_name
) )
if not graph_deployment: if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found") raise ValueError(f"Graph deployment {graph_deployment_name} not found")
......
...@@ -13,24 +13,33 @@ ...@@ -13,24 +13,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .kube import KubernetesAPI import logging
from .planner_connector import PlannerConnector from typing import Optional
from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class KubernetesConnector(PlannerConnector): class KubernetesConnector(PlannerConnector):
def __init__(self, namespace: str): def __init__(self, dynamo_namespace: str, k8s_namespace: Optional[str] = None):
self.kube_api = KubernetesAPI() self.kube_api = KubernetesAPI(k8s_namespace)
self.namespace = namespace self.dynamo_namespace = dynamo_namespace
async def add_component(self, component_name: str, blocking: bool = True): async def add_component(self, component_name: str, blocking: bool = True):
"""Add a component by increasing its replica count by 1""" """Add a component by increasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment( deployment = await self.kube_api.get_graph_deployment(
component_name, self.namespace component_name, self.dynamo_namespace
) )
if deployment is None: if deployment is None:
raise ValueError( raise ValueError(
f"Graph not found for component {component_name} in dynamo namespace {self.namespace}" f"Graph not found for component {component_name} in dynamo namespace {self.dynamo_namespace}"
) )
# get current replicas or 1 if not found # get current replicas or 1 if not found
current_replicas = self._get_current_replicas(deployment, component_name) current_replicas = self._get_current_replicas(deployment, component_name)
await self.kube_api.update_graph_replicas( await self.kube_api.update_graph_replicas(
...@@ -45,13 +54,15 @@ class KubernetesConnector(PlannerConnector): ...@@ -45,13 +54,15 @@ class KubernetesConnector(PlannerConnector):
async def remove_component(self, component_name: str, blocking: bool = True): async def remove_component(self, component_name: str, blocking: bool = True):
"""Remove a component by decreasing its replica count by 1""" """Remove a component by decreasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment( deployment = await self.kube_api.get_graph_deployment(
component_name, self.namespace component_name, self.dynamo_namespace
) )
if deployment is None: if deployment is None:
raise ValueError( raise ValueError(
f"Graph {component_name} not found for namespace {self.namespace}" f"Graph {component_name} not found for namespace {self.dynamo_namespace}"
) )
# get current replicas or 1 if not found # get current replicas or 1 if not found
current_replicas = self._get_current_replicas(deployment, component_name) current_replicas = self._get_current_replicas(deployment, component_name)
if current_replicas > 0: if current_replicas > 0:
...@@ -65,6 +76,68 @@ class KubernetesConnector(PlannerConnector): ...@@ -65,6 +76,68 @@ class KubernetesConnector(PlannerConnector):
self._get_graph_deployment_name(deployment) self._get_graph_deployment_name(deployment)
) )
async def _validate_components_same_deployment(
self, target_replicas: dict[str, int]
) -> dict:
"""
Validate that all target components belong to the same DynamoGraphDeployment.
"""
if not target_replicas:
raise ValueError("target_replicas cannot be empty")
# Get deployment for first component
first_component = next(iter(target_replicas))
deployment = await self.kube_api.get_graph_deployment(
first_component, self.dynamo_namespace
)
if deployment is None:
raise ValueError(
f"Component {first_component} not found in namespace {self.dynamo_namespace}"
)
# Validate that all components belong to the same DGD
graph_name = deployment["metadata"]["name"]
for component in target_replicas:
comp_deployment = await self.kube_api.get_graph_deployment(
component, self.dynamo_namespace
)
if comp_deployment is None:
raise ValueError(
f"Component {component} not found in namespace {self.dynamo_namespace}"
)
if comp_deployment["metadata"]["name"] != graph_name:
raise ValueError(
f"Component {component} belongs to graph '{comp_deployment['metadata']['name']}' "
f"but expected graph '{graph_name}'. All components must belong to the same GraphDeployment."
)
return deployment
async def set_component_replicas(
self, target_replicas: dict[str, int], blocking: bool = True
):
"""Set the replicas for multiple components at once"""
deployment = await self._validate_components_same_deployment(target_replicas)
if not await self.kube_api.is_deployment_ready(
self._get_graph_deployment_name(deployment)
):
logger.warning(
f"Deployment {self._get_graph_deployment_name(deployment)} is not ready, ignoring this scaling"
)
return
for component_name, replicas in target_replicas.items():
await self.kube_api.update_graph_replicas(
self._get_graph_deployment_name(deployment),
component_name,
replicas,
)
if blocking:
await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment)
)
def _get_current_replicas(self, deployment: dict, component_name: str) -> int: def _get_current_replicas(self, deployment: dict, component_name: str) -> int:
"""Get the current replicas for a component in a graph deployment""" """Get the current replicas for a component in a graph deployment"""
return ( return (
...@@ -84,12 +157,13 @@ if __name__ == "__main__": ...@@ -84,12 +157,13 @@ if __name__ == "__main__":
import asyncio import asyncio
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--namespace", type=str, default="dynamo") parser.add_argument("--dynamo_namespace", type=str, default="dynamo")
parser.add_argument("--k8s_namespace", type=str, default="default")
parser.add_argument("--action", type=str, choices=["add", "remove"]) parser.add_argument("--action", type=str, choices=["add", "remove"])
parser.add_argument("--component", type=str, default="planner") parser.add_argument("--component", type=str, default="planner")
parser.add_argument("--blocking", action="store_true") parser.add_argument("--blocking", action="store_true")
args = parser.parse_args() args = parser.parse_args()
connector = KubernetesConnector(args.namespace) connector = KubernetesConnector(args.dynamo_namespace, args.k8s_namespace)
if args.action == "add": if args.action == "add":
task = connector.add_component(args.component, args.blocking) task = connector.add_component(args.component, args.blocking)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List
import filelock
from dynamo.planner.circusd import CircusController
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class LocalConnector(PlannerConnector):
def __init__(self, namespace: str, runtime: DistributedRuntime):
"""
Initialize LocalConnector and connect to CircusController.
Args:
namespace: The Dynamo namespace
runtime: Optional DistributedRuntime instance
"""
self.namespace = namespace
self.runtime = runtime
self.state_file = Path.home() / ".dynamo" / "state" / f"{namespace}.json"
self.circus = CircusController.from_state_file(namespace)
self._lockfile = self.state_file.with_suffix(".lock")
self._file_lock = filelock.FileLock(self._lockfile)
self.worker_client: Any | None = None
self.prefill_client: Any | None = None
self.etcd_client: Any | None = None
async def _load_state(self) -> Dict[str, Any]:
"""Load state from state file.
Returns:
State dictionary
"""
if not self.state_file.exists():
raise FileNotFoundError(f"State file not found: {self.state_file}")
with self._file_lock:
with open(self.state_file, "r") as f:
return json.load(f)
async def _save_state(self, state: Dict[str, Any]) -> bool:
"""Save state to state file.
Args:
state: State dictionary to save
Returns:
True if successful
"""
try:
with self._file_lock:
with open(self.state_file, "w") as f:
json.dump(state, f, indent=2)
return True
except Exception as e:
logger.error(f"Failed to save state: {e}")
return False
async def _get_available_gpus(self) -> List[str]:
"""Get list of unallocated GPU IDs.
Returns:
List of available GPU IDs
"""
state = await self._load_state()
system_resources = state.get("environment", {}).get("SYSTEM_RESOURCES", {})
all_gpus = set(str(gpu) for gpu in system_resources.get("gpu_info", []))
allocated_gpus: set[str] = set()
for component_info in state.get("components", {}).values():
resources = component_info.get("resources", {})
gpu_list = resources.get("allocated_gpus", [])
allocated_gpus.update(str(gpu) for gpu in gpu_list)
logger.info(f"Allocated GPUs: {allocated_gpus}")
available = sorted(list(all_gpus - allocated_gpus))
logger.info(f"Available GPUs: {available}")
return available
async def add_component(self, component_name: str, blocking: bool = True) -> bool:
"""
Add a component. The steps are as follows:
1. Load state
2. Find max suffix to create unique watcher name
3. Built environment and command for watcher
4. Block until component is running
Args:
component_name: Name of the component
Returns:
True if successful
"""
state = await self._load_state()
# Find max suffix
max_suffix = 0
for watcher_name in state["components"].keys():
if watcher_name.startswith(f"{self.namespace}_{component_name}_"):
suffix = int(
watcher_name.replace(f"{self.namespace}_{component_name}_", "")
)
max_suffix = max(max_suffix, suffix)
watcher_name = f"{self.namespace}_{component_name}_{max_suffix + 1}"
if component_name not in [
c.replace(f"{self.namespace}_", "") for c in state["components"]
]:
raise ValueError(
f"Component {component_name} not found in state configuration"
)
# Get base command and config
component_info = state["components"][f"{self.namespace}_{component_name}"]
base_cmd = component_info["cmd"].split("--worker-env")[0].strip()
service_config = state["environment"].get("DYNAMO_SERVICE_CONFIG")
# Build environment
watcher_env = os.environ.copy()
if component_name in ["VllmWorker", "PrefillWorker"]:
available_gpus = await self._get_available_gpus()
if not available_gpus:
raise ValueError("No GPUs available for allocation")
gpu_id = available_gpus[0]
watcher_env["CUDA_VISIBLE_DEVICES"] = gpu_id
watcher_env["DYNAMO_SERVICE_CONFIG"] = service_config
# Build worker env list and command
worker_env_list = [watcher_env]
worker_env_arg = json.dumps(worker_env_list)
# We add a custom component name to ensure that the lease is attatched to this specific watcher
full_cmd = f"{base_cmd} --worker-env '{worker_env_arg}' --custom-component-name '{watcher_name}'"
pre_add_endpoint_ids = await self._count_instance_ids(component_name)
logger.info(f"Pre-add endpoint IDs: {pre_add_endpoint_ids}")
logger.info(f"Adding watcher {watcher_name}")
success = await self.circus.add_watcher(
name=watcher_name, cmd=full_cmd, env=watcher_env, singleton=True
)
if success:
resources = {}
if component_name in ["VllmWorker", "PrefillWorker"]:
resources["allocated_gpus"] = [gpu_id]
state["components"][watcher_name] = {
"watcher_name": watcher_name,
"cmd": full_cmd,
"resources": resources,
}
await self._save_state(state)
logger.info(
f"Succesfully created {watcher_name}. Waiting for worker to start..."
)
if blocking:
required_endpoint_ids = pre_add_endpoint_ids + 1
while True:
current_endpoint_ids = await self._count_instance_ids(component_name)
if current_endpoint_ids == required_endpoint_ids:
break
logger.info(
f"Waiting for {component_name} to start. Current endpoint IDs: {current_endpoint_ids}, Required endpoint IDs: {required_endpoint_ids}"
)
await asyncio.sleep(5)
return success
async def remove_component(
self, component_name: str, blocking: bool = True
) -> bool:
"""
Remove a component. The initial components are not numbered so we simply remove their resources
and lease but keep the entry in order to use the cmd. This allows us to re-add the component
without having to re-specify the cmd. For components that have been added, we remove their entry
entry
Args:
component_name: Name of the component
Returns:
True if successful
"""
logger.info(f"Attempting to remove component {component_name}")
state = await self._load_state()
matching_components = {}
base_name = f"{self.namespace}_{component_name}"
base_name_with_underscore = f"{base_name}_"
for watcher_name in state["components"].keys():
if watcher_name == base_name:
matching_components[0] = watcher_name
elif watcher_name.startswith(base_name_with_underscore):
suffix = int(watcher_name.replace(base_name_with_underscore, ""))
matching_components[suffix] = watcher_name
if not matching_components:
logger.error(f"No matching components found for {component_name}")
return False
highest_suffix = max(matching_components.keys())
target_watcher = matching_components[highest_suffix]
logger.info(f"Removing watcher {target_watcher}")
success = await self.circus.remove_watcher(
name=target_watcher, blocking=blocking
)
if not blocking:
logger.info(
f"Circus remove_watcher for {target_watcher} {'succeeded' if success else 'failed'}"
)
if success:
if highest_suffix > 0: # Numbered watcher - remove entire entry
if target_watcher in state["components"]:
del state["components"][target_watcher]
else: # Base watcher - just clear resources and lease
if target_watcher in state["components"]:
state["components"][target_watcher]["resources"] = {}
state["components"][target_watcher]["lease"] = None
await self._save_state(state)
return success
async def _count_instance_ids(self, component_name: str) -> int:
"""
Count the instance IDs for the 'generate' endpoint of given component.
Args:
component_name: Name of the component
Returns:
Number of endpoint IDs for a component
"""
if component_name == "VllmWorker":
if self.worker_client is None:
self.worker_client = (
await self.runtime.namespace(self.namespace)
.component(component_name)
.endpoint("generate")
.client()
)
worker_ids = self.worker_client.instance_ids()
return len(worker_ids)
elif component_name == "PrefillWorker":
if self.prefill_client is None:
self.prefill_client = (
await self.runtime.namespace(self.namespace)
.component(component_name)
.endpoint("mock")
.client()
)
prefill_ids = self.prefill_client.instance_ids()
return len(prefill_ids)
else:
raise ValueError(f"Component {component_name} not supported")
async def _revoke_lease(self, lease_id: int) -> bool:
"""
Wrapper function around the etcd client to revoke a lease
Args:
lease_id: Lease ID to revoke
Returns:
True if successful
"""
if self.etcd_client is None:
self.etcd_client = self.runtime.etcd_client() # type: ignore
try:
await self.etcd_client.revoke_lease(lease_id)
logger.info(f"Revoked lease {lease_id}")
return True
except Exception as e:
logger.error(f"Failed to revoke lease {lease_id}: {e}")
return False
def __del__(self):
"""Cleanup circus controller connection on deletion."""
if hasattr(self, "circus"):
self.circus.close()
...@@ -21,11 +21,7 @@ from pydantic import BaseModel ...@@ -21,11 +21,7 @@ from pydantic import BaseModel
from dynamo.planner.defaults import SLAPlannerDefaults from dynamo.planner.defaults import SLAPlannerDefaults
from dynamo.planner.utils.planner_core import start_sla_planner from dynamo.planner.utils.planner_core import start_sla_planner
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
from dynamo.sdk.core.protocol.interface import ComponentType
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -38,79 +34,107 @@ class RequestType(BaseModel): ...@@ -38,79 +34,107 @@ class RequestType(BaseModel):
text: str text: str
@service( @dynamo_worker(static=False)
dynamo={ async def init_planner(runtime: DistributedRuntime, args):
"namespace": "dynamo", await asyncio.sleep(INIT_PLANNER_START_DELAY)
"component_type": ComponentType.PLANNER,
}, await start_sla_planner(runtime, args)
resources={"cpu": "10", "memory": "20Gi"},
workers=1, component = runtime.namespace(SLAPlannerDefaults.namespace).component("Planner")
image=DYNAMO_IMAGE, await component.create_service()
)
class Planner: async def generate(request: RequestType):
def __init__(self):
configure_dynamo_logging(service_name="Planner")
logger.info("Starting planner")
self.runtime = dynamo_context["runtime"]
config = ServiceConfig.get_instance()
# Get namespace directly from dynamo_context as it contains the active namespace
self.namespace = dynamo_context["namespace"]
config_instance = config.get("Planner", {})
self.args = argparse.Namespace(
namespace=self.namespace,
environment=config_instance.get(
"environment", SLAPlannerDefaults.environment
),
backend=config_instance.get("backend", SLAPlannerDefaults.backend),
no_operation=config_instance.get(
"no-operation", SLAPlannerDefaults.no_operation
),
log_dir=config_instance.get("log-dir", SLAPlannerDefaults.log_dir),
adjustment_interval=config_instance.get(
"adjustment-interval", SLAPlannerDefaults.adjustment_interval
),
max_gpu_budget=config_instance.get(
"max-gpu-budget", SLAPlannerDefaults.max_gpu_budget
),
min_endpoint=config_instance.get(
"min-endpoint", SLAPlannerDefaults.min_endpoint
),
decode_engine_num_gpu=config_instance.get(
"decode-engine-num-gpu", SLAPlannerDefaults.decode_engine_num_gpu
),
prefill_engine_num_gpu=config_instance.get(
"prefill-engine-num-gpu", SLAPlannerDefaults.prefill_engine_num_gpu
),
prometheus_endpoint=config_instance.get(
"prometheus-endpoint", SLAPlannerDefaults.prometheus_endpoint
),
profile_results_dir=config_instance.get(
"profile-results-dir", SLAPlannerDefaults.profile_results_dir
),
isl=config_instance.get("isl", SLAPlannerDefaults.isl),
osl=config_instance.get("osl", SLAPlannerDefaults.osl),
ttft=config_instance.get("ttft", SLAPlannerDefaults.ttft),
itl=config_instance.get("itl", SLAPlannerDefaults.itl),
load_predictor=config_instance.get(
"load-predictor", SLAPlannerDefaults.load_predictor
),
load_prediction_window_size=config_instance.get(
"load-prediction-window-size",
SLAPlannerDefaults.load_prediction_window_size,
),
)
@async_on_start
async def async_init(self):
await asyncio.sleep(INIT_PLANNER_START_DELAY)
logger.info("Calling start_planner")
await start_sla_planner(self.runtime, self.args)
logger.info("Planner started")
@endpoint()
async def generate(self, request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint""" """Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint" yield "mock endpoint"
generate_endpoint = component.endpoint("generate")
await generate_endpoint.serve_endpoint(generate)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SLA Planner")
parser.add_argument(
"--environment",
default=SLAPlannerDefaults.environment,
choices=["kubernetes"],
help="Environment type",
)
parser.add_argument(
"--backend",
default=SLAPlannerDefaults.backend,
choices=["vllm"],
help="Backend type",
)
parser.add_argument(
"--no-operation",
action="store_true",
default=SLAPlannerDefaults.no_operation,
help="Enable no-operation mode",
)
parser.add_argument(
"--log-dir", default=SLAPlannerDefaults.log_dir, help="Log directory path"
)
parser.add_argument(
"--adjustment-interval",
type=int,
default=SLAPlannerDefaults.adjustment_interval,
help="Adjustment interval in seconds",
)
parser.add_argument(
"--max-gpu-budget",
type=int,
default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget",
)
parser.add_argument(
"--min-endpoint",
type=int,
default=SLAPlannerDefaults.min_endpoint,
help="Minimum number of endpoints",
)
parser.add_argument(
"--decode-engine-num-gpu",
type=int,
default=SLAPlannerDefaults.decode_engine_num_gpu,
help="Number of GPUs for decode engine",
)
parser.add_argument(
"--prefill-engine-num-gpu",
type=int,
default=SLAPlannerDefaults.prefill_engine_num_gpu,
help="Number of GPUs for prefill engine",
)
parser.add_argument(
"--profile-results-dir",
default=SLAPlannerDefaults.profile_results_dir,
help="Profile results directory",
)
parser.add_argument(
"--isl", type=int, default=SLAPlannerDefaults.isl, help="Input sequence length"
)
parser.add_argument(
"--osl", type=int, default=SLAPlannerDefaults.osl, help="Output sequence length"
)
parser.add_argument(
"--ttft",
type=float,
default=SLAPlannerDefaults.ttft,
help="Time to first token",
)
parser.add_argument(
"--itl", type=float, default=SLAPlannerDefaults.itl, help="Inter-token latency"
)
parser.add_argument(
"--load-predictor",
default=SLAPlannerDefaults.load_predictor,
help="Load predictor type",
)
parser.add_argument(
"--load-prediction-window-size",
type=int,
default=SLAPlannerDefaults.load_prediction_window_size,
help="Load prediction window size",
)
args = parser.parse_args()
asyncio.run(init_planner(args))
...@@ -13,55 +13,69 @@ ...@@ -13,55 +13,69 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import subprocess import subprocess
import tempfile import tempfile
import yaml import yaml
from dynamo.sdk import service from dynamo.planner.defaults import SLAPlannerDefaults
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@service( @dynamo_worker(static=False)
dynamo={ async def worker(runtime: DistributedRuntime):
"namespace": "dynamo", """Initialize and run Prometheus server with Dynamo config."""
}, config = ServiceConfig.get_parsed_config("Prometheus")
workers=1,
image=DYNAMO_IMAGE, logger.info(f"Prometheus config: {config}")
)
class Prometheus: await start_prometheus_server(config)
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
self.config = ServiceConfig.get_parsed_config("Prometheus") async def start_prometheus_server(config):
self.process = None logger.info("Starting prometheus server...")
logger.info(f"Prometheus config: {self.config}") temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
yaml.dump(config, temp_file)
self.start_prometheus_server() temp_file.close()
config_path = temp_file.name
def start_prometheus_server(self):
logger.info("Starting prometheus server...") prometheus_port = SLAPlannerDefaults.port
cmd = [
self.temp_file = tempfile.NamedTemporaryFile( "prometheus",
mode="w", suffix=".yml", delete=False f"--config.file={config_path}",
) f"--web.listen-address=0.0.0.0:{prometheus_port}",
yaml.dump(self.config, self.temp_file) ]
self.temp_file.close()
config_path = self.temp_file.name logger.info(f"Prometheus cmd: {cmd}")
cmd = [ process = subprocess.Popen(
"prometheus", cmd,
f"--config.file={config_path}", stdout=None,
] stderr=None,
)
logger.info(f"Prometheus cmd: {cmd}")
# Keep the worker running
self.process = subprocess.Popen( try:
cmd, while True:
stdout=None, await asyncio.sleep(1)
stderr=None, if process.poll() is not None:
) logger.error("Prometheus process died")
break
except asyncio.CancelledError:
logger.info("Shutting down Prometheus...")
process.terminate()
process.wait()
raise
if __name__ == "__main__":
# The dynamo_worker decorator handles runtime setup
import asyncio
asyncio.run(worker())
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from dynamo.planner import KubernetesConnector, LocalConnector from dynamo.planner import KubernetesConnector
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import ( from dynamo.planner.utils.perf_interpolation import (
...@@ -47,22 +47,35 @@ class Metrics: ...@@ -47,22 +47,35 @@ class Metrics:
p_load: Optional[float] = None p_load: Optional[float] = None
d_load: Optional[float] = None d_load: Optional[float] = None
def is_valid(self) -> bool:
"""Check if all metrics are valid (not None and not NaN)."""
return (
self.ttft is not None
and self.itl is not None
and self.isl is not None
and self.osl is not None
and not math.isnan(self.ttft)
and not math.isnan(self.itl)
and not math.isnan(self.isl)
and not math.isnan(self.osl)
)
class Planner: class Planner:
def __init__(self, runtime: DistributedRuntime, args: argparse.Namespace): def __init__(self, runtime: DistributedRuntime, args: argparse.Namespace):
self.runtime = runtime self.runtime = runtime
self.args = args self.args = args
self.namespace = args.namespace self.namespace = SLAPlannerDefaults.namespace
if not args.no_operation: if not args.no_operation:
if args.environment == "local": if args.environment == "kubernetes":
self.connector = LocalConnector(args.namespace, runtime) self.connector = KubernetesConnector(self.namespace)
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(args.namespace)
else: else:
raise ValueError(f"Invalid environment: {args.environment}") raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient(args.prometheus_endpoint) self.prometheus_api_client = PrometheusAPIClient(
SLAPlannerDefaults.prometheus_endpoint
)
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor]( self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor](
window_size=args.load_prediction_window_size, window_size=args.load_prediction_window_size,
...@@ -167,6 +180,13 @@ class Planner: ...@@ -167,6 +180,13 @@ class Planner:
async def make_adjustments(self): async def make_adjustments(self):
try: try:
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
self.p_endpoints, self.d_endpoints = await self.get_workers_info() self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info( logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}" f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
...@@ -224,7 +244,14 @@ class Planner: ...@@ -224,7 +244,14 @@ class Planner:
# compute how many replicas are needed for decode # compute how many replicas are needed for decode
# 1. apply d_correction_factor to the ITL SLA # 1. apply d_correction_factor to the ITL SLA
corrected_itl = self.args.itl / self.d_correction_factor # Prevent divide by zero when d_correction_factor is 0 (no metrics yet)
if self.d_correction_factor <= 0:
logger.warning(
f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0"
)
corrected_itl = self.args.itl
else:
corrected_itl = self.args.itl / self.d_correction_factor
# 2. reversely find out what is best throughput/gpu that can achieve corrected_itl under the predicted context length # 2. reversely find out what is best throughput/gpu that can achieve corrected_itl under the predicted context length
pred_decode_thpt_per_gpu = ( pred_decode_thpt_per_gpu = (
self.decode_interpolator.find_best_throughput_per_gpu( self.decode_interpolator.find_best_throughput_per_gpu(
...@@ -272,33 +299,11 @@ class Planner: ...@@ -272,33 +299,11 @@ class Planner:
return return
if not self.args.no_operation: if not self.args.no_operation:
# scale up/down the number of prefill/decode non-blockingly target_replicas = {
# TODO: add a check to avoid scaling before the previous scaling is completed WORKER_COMPONENT_NAMES[self.args.backend].prefill_worker: next_num_p,
if next_num_p > len(self.p_endpoints): WORKER_COMPONENT_NAMES[self.args.backend].decode_worker: next_num_d,
for _ in range(next_num_p - len(self.p_endpoints)): }
self.connector.add_component( await self.connector.set_component_replicas(target_replicas, blocking=False)
WORKER_COMPONENT_NAMES[self.args.backend].prefill_worker,
blocking=False,
)
elif next_num_p < len(self.p_endpoints):
for _ in range(len(self.p_endpoints) - next_num_p):
self.connector.remove_component(
WORKER_COMPONENT_NAMES[self.args.backend].prefill_worker,
blocking=False,
)
if next_num_d > len(self.d_endpoints):
for _ in range(next_num_d - len(self.d_endpoints)):
self.connector.add_component(
WORKER_COMPONENT_NAMES[self.args.backend].decode_worker,
blocking=False,
)
elif next_num_d < len(self.d_endpoints):
for _ in range(len(self.d_endpoints) - next_num_d):
self.connector.remove_component(
WORKER_COMPONENT_NAMES[self.args.backend].decode_worker,
blocking=False,
)
async def run(self): async def run(self):
"""Main loop for the planner""" """Main loop for the planner"""
...@@ -329,12 +334,6 @@ async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespac ...@@ -329,12 +334,6 @@ async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespac
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Common planner arguments # Common planner arguments
parser.add_argument(
"--namespace",
type=str,
default=SLAPlannerDefaults.namespace,
help="Namespace planner will look at",
)
parser.add_argument( parser.add_argument(
"--environment", "--environment",
type=str, type=str,
......
...@@ -27,40 +27,55 @@ class PrometheusAPIClient: ...@@ -27,40 +27,55 @@ class PrometheusAPIClient:
def __init__(self, url: str): def __init__(self, url: str):
self.prom = PrometheusConnect(url=url, disable_ssl=True) self.prom = PrometheusConnect(url=url, disable_ssl=True)
def get_avg_inter_token_latency(self, interval: str): def _get_average_metric(
self, metric_name: str, interval: str, operation_name: str
) -> float:
"""
Helper method to get average metrics using the pattern:
increase(metric_sum[interval])/increase(metric_count[interval])
Args:
metric_name: Base metric name (e.g., 'nv_llm_http_service_inter_token_latency_seconds')
interval: Time interval for the query (e.g., '60s')
operation_name: Human-readable name for error logging
Returns:
Average metric value or 0 if no data/error
"""
try: try:
return float( query = f"increase({metric_name}_sum[{interval}])/increase({metric_name}_count[{interval}])"
self.prom.custom_query( result = self.prom.custom_query(query=query)
query=f"increase(nv_llm_http_service_inter_token_latency_seconds_sum[{interval}])/increase(nv_llm_http_service_inter_token_latency_seconds_count[{interval}])", if not result:
)[0]["value"][1] # No data available yet (no requests made) - return 0 silently
) return 0
return float(result[0]["value"][1])
except Exception as e: except Exception as e:
logger.error(f"Error getting avg inter token latency: {e}") logger.error(f"Error getting {operation_name}: {e}")
return 0 return 0
def get_avg_inter_token_latency(self, interval: str):
return self._get_average_metric(
"nv_llm_http_service_inter_token_latency_seconds",
interval,
"avg inter token latency",
)
def get_avg_time_to_first_token(self, interval: str): def get_avg_time_to_first_token(self, interval: str):
try: return self._get_average_metric(
return float( "nv_llm_http_service_time_to_first_token_seconds",
self.prom.custom_query( interval,
query=f"increase(nv_llm_http_service_time_to_first_token_seconds_sum[{interval}])/increase(nv_llm_http_service_time_to_first_token_seconds_count[{interval}])", "avg time to first token",
)[0]["value"][1] )
)
except Exception as e:
logger.error(f"Error getting avg time to first token: {e}")
return 0
def get_avg_request_duration(self, interval: str): def get_avg_request_duration(self, interval: str):
try: return self._get_average_metric(
return float( "nv_llm_http_service_request_duration_seconds",
self.prom.custom_query( interval,
query=f"increase(nv_llm_http_service_request_duration_seconds_sum[{interval}])/increase(nv_llm_http_service_request_duration_seconds_count[{interval}])", "avg request duration",
)[0]["value"][1] )
)
except Exception as e:
logger.error(f"Error getting avg request duration: {e}")
return 0
def get_avg_request_count(self, interval: str): def get_avg_request_count(self, interval: str):
# This function follows a different query pattern than the other metrics
try: try:
raw_res = self.prom.custom_query( raw_res = self.prom.custom_query(
query=f"increase(nv_llm_http_service_requests_total[{interval}])" query=f"increase(nv_llm_http_service_requests_total[{interval}])"
...@@ -75,23 +90,15 @@ class PrometheusAPIClient: ...@@ -75,23 +90,15 @@ class PrometheusAPIClient:
return 0 return 0
def get_avg_input_sequence_tokens(self, interval: str): def get_avg_input_sequence_tokens(self, interval: str):
try: return self._get_average_metric(
return float( "nv_llm_http_service_input_sequence_tokens",
self.prom.custom_query( interval,
query=f"increase(nv_llm_http_service_input_sequence_tokens_sum[{interval}])/increase(nv_llm_http_service_input_sequence_tokens_count[{interval}])", "avg input sequence tokens",
)[0]["value"][1] )
)
except Exception as e:
logger.error(f"Error getting avg input sequence tokens: {e}")
return 0
def get_avg_output_sequence_tokens(self, interval: str): def get_avg_output_sequence_tokens(self, interval: str):
try: return self._get_average_metric(
return float( "nv_llm_http_service_output_sequence_tokens",
self.prom.custom_query( interval,
query=f"increase(nv_llm_http_service_output_sequence_tokens_sum[{interval}])/increase(nv_llm_http_service_output_sequence_tokens_count[{interval}])", "avg output sequence tokens",
)[0]["value"][1] )
)
except Exception as e:
logger.error(f"Error getting avg output sequence tokens: {e}")
return 0
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -39,9 +39,45 @@ def k8s_api(mock_custom_api, mock_config): ...@@ -39,9 +39,45 @@ def k8s_api(mock_custom_api, mock_config):
return KubernetesAPI() return KubernetesAPI()
@pytest.fixture
def k8s_api_with_namespace(mock_custom_api, mock_config):
return KubernetesAPI(k8s_namespace="test-namespace")
def test_kubernetes_api_init_with_namespace(mock_custom_api, mock_config):
"""Test KubernetesAPI initialization with custom namespace"""
api = KubernetesAPI(k8s_namespace="custom-namespace")
assert api.current_namespace == "custom-namespace"
def test_kubernetes_api_init_without_namespace(mock_custom_api, mock_config):
"""Test KubernetesAPI initialization without custom namespace"""
api = KubernetesAPI()
# Should use the default namespace logic
assert api.current_namespace == "default"
def test_get_graph_deployment_from_name(k8s_api, mock_custom_api):
"""Test _get_graph_deployment_from_name method"""
mock_deployment = {"metadata": {"name": "test-deployment"}}
mock_custom_api.get_namespaced_custom_object.return_value = mock_deployment
result = k8s_api._get_graph_deployment_from_name("test-deployment")
assert result == mock_deployment
mock_custom_api.get_namespaced_custom_object.assert_called_once_with(
group="nvidia.com",
version="v1alpha1",
namespace=k8s_api.current_namespace,
plural="dynamographdeployments",
name="test-deployment",
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api): async def test_is_deployment_ready_true(k8s_api, mock_custom_api):
# Mock the get_graph_deployment response """Test is_deployment_ready method when deployment is ready"""
# Mock the _get_graph_deployment_from_name response
mock_deployment: Dict[str, Any] = { mock_deployment: Dict[str, Any] = {
"status": { "status": {
"conditions": [ "conditions": [
...@@ -49,22 +85,18 @@ async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api) ...@@ -49,22 +85,18 @@ async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api)
] ]
} }
} }
k8s_api.get_graph_deployment = AsyncMock(return_value=mock_deployment)
# Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
# Verify get_graph_deployment was called # Mock the method on the instance
k8s_api.get_graph_deployment.assert_called_once_with( with patch.object(
"test-deployment", k8s_api.current_namespace k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
) ):
result = await k8s_api.is_deployment_ready("test-deployment")
assert result is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api): async def test_is_deployment_ready_false(k8s_api, mock_custom_api):
# Mock the get_graph_deployment response with not ready status """Test is_deployment_ready method when deployment is not ready"""
mock_deployment: Dict[str, Any] = { mock_deployment: Dict[str, Any] = {
"status": { "status": {
"conditions": [ "conditions": [
...@@ -76,54 +108,115 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api) ...@@ -76,54 +108,115 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api)
] ]
} }
} }
k8s_api.get_graph_deployment = AsyncMock(return_value=mock_deployment)
# Test with minimal attempts and delay for faster testing # Mock the method on the instance
with pytest.raises(TimeoutError) as exc_info: with patch.object(
await k8s_api.wait_for_graph_deployment_ready( k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
"test-deployment", max_attempts=2, delay_seconds=0.1 ):
) result = await k8s_api.is_deployment_ready("test-deployment")
assert result is False
assert "is not ready after" in str(exc_info.value)
assert k8s_api.get_graph_deployment.call_count == 2 @pytest.mark.asyncio
async def test_is_deployment_ready_not_found(k8s_api, mock_custom_api):
"""Test is_deployment_ready method when deployment is not found"""
# Mock the method on the instance
with patch.object(k8s_api, "_get_graph_deployment_from_name", return_value=None):
with pytest.raises(ValueError) as exc_info:
await k8s_api.is_deployment_ready("test-deployment")
assert "not found" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_not_found(k8s_api, mock_custom_api): async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api):
# Mock the get_graph_deployment response to return None """Test wait_for_graph_deployment_ready when deployment becomes ready"""
k8s_api.get_graph_deployment = AsyncMock(return_value=None) # Mock the _get_graph_deployment_from_name response
mock_deployment: Dict[str, Any] = {
"status": {
"conditions": [
{"type": "Ready", "status": "True", "message": "Deployment is ready"}
]
}
}
# Test with minimal attempts and delay for faster testing # Mock the method on the instance
with pytest.raises(ValueError) as exc_info: with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
# Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready( await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1 "test-deployment", max_attempts=2, delay_seconds=0.1
) )
assert "not found" in str(exc_info.value)
assert k8s_api.get_graph_deployment.call_count == 1 @pytest.mark.asyncio
async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api):
"""Test wait_for_graph_deployment_ready when deployment times out"""
# Mock the _get_graph_deployment_from_name response with not ready status
mock_deployment: Dict[str, Any] = {
"status": {
"conditions": [
{
"type": "Ready",
"status": "False",
"message": "Deployment is not ready",
}
]
}
}
# Mock the method on the instance
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
# Test with minimal attempts and delay for faster testing
with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
assert "is not ready after" in str(exc_info.value)
@pytest.mark.asyncio
async def test_wait_for_graph_deployment_not_found(k8s_api, mock_custom_api):
"""Test wait_for_graph_deployment_ready when deployment is not found"""
# Mock the _get_graph_deployment_from_name response to return None
with patch.object(k8s_api, "_get_graph_deployment_from_name", return_value=None):
# Test with minimal attempts and delay for faster testing
with pytest.raises(ValueError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
assert "not found" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_no_conditions(k8s_api, mock_custom_api): async def test_wait_for_graph_deployment_no_conditions(k8s_api, mock_custom_api):
# Mock the get_graph_deployment response with no conditions """Test wait_for_graph_deployment_ready when deployment has no conditions"""
# Mock the _get_graph_deployment_from_name response with no conditions
mock_deployment: Dict[str, Any] = {"status": {}} mock_deployment: Dict[str, Any] = {"status": {}}
k8s_api.get_graph_deployment = AsyncMock(return_value=mock_deployment)
# Test with minimal attempts and delay for faster testing with patch.object(
with pytest.raises(TimeoutError) as exc_info: k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
await k8s_api.wait_for_graph_deployment_ready( ):
"test-deployment", max_attempts=2, delay_seconds=0.1 # Test with minimal attempts and delay for faster testing
) with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
assert "is not ready after" in str(exc_info.value) assert "is not ready after" in str(exc_info.value)
assert k8s_api.get_graph_deployment.call_count == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_ready_on_second_attempt( async def test_wait_for_graph_deployment_ready_on_second_attempt(
k8s_api, mock_custom_api k8s_api, mock_custom_api
): ):
# Mock the get_graph_deployment response to return not ready first, then ready """Test wait_for_graph_deployment_ready when deployment becomes ready on second attempt"""
# Mock the _get_graph_deployment_from_name response to return not ready first, then ready
mock_deployment_not_ready: Dict[str, Any] = { mock_deployment_not_ready: Dict[str, Any] = {
"status": { "status": {
"conditions": [ "conditions": [
...@@ -142,13 +235,13 @@ async def test_wait_for_graph_deployment_ready_on_second_attempt( ...@@ -142,13 +235,13 @@ async def test_wait_for_graph_deployment_ready_on_second_attempt(
] ]
} }
} }
k8s_api.get_graph_deployment = AsyncMock(
side_effect=[mock_deployment_not_ready, mock_deployment_ready]
)
# Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
assert k8s_api.get_graph_deployment.call_count == 2 with patch.object(
k8s_api,
"_get_graph_deployment_from_name",
side_effect=[mock_deployment_not_ready, mock_deployment_ready],
):
# Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
...@@ -63,7 +63,7 @@ async def test_add_component_increases_replicas(kubernetes_connector, mock_kube_ ...@@ -63,7 +63,7 @@ async def test_add_component_increases_replicas(kubernetes_connector, mock_kube_
# Assert # Assert
mock_kube_api.get_graph_deployment.assert_called_once_with( mock_kube_api.get_graph_deployment.assert_called_once_with(
component_name, kubernetes_connector.namespace component_name, kubernetes_connector.dynamo_namespace
) )
mock_kube_api.update_graph_replicas.assert_called_once_with( mock_kube_api.update_graph_replicas.assert_called_once_with(
"test-graph", component_name, 2 "test-graph", component_name, 2
......
...@@ -29,7 +29,7 @@ The script will recommend the best TP size for prefill and decode, as well as th ...@@ -29,7 +29,7 @@ The script will recommend the best TP size for prefill and decode, as well as th
2025-05-16 15:20:24 - __main__ - INFO - Suggested planner upper/lower bound for decode kv cache utilization: 0.20/0.10 2025-05-16 15:20:24 - __main__ - INFO - Suggested planner upper/lower bound for decode kv cache utilization: 0.20/0.10
``` ```
After finding the best TP size for prefill and decode, the script will then interpolate the TTFT with ISL and ITL with active KV cache and decode context length. This is to provide a more accurate estimation of the performance when ISL and OSL changes and will be used in the sla-planner. The results will be saved to `<output_dir>/<decode/prefill>_tp<best_tp>_interpolation`. After finding the best TP size for prefill and decode, the script will then interpolate the TTFT with ISL and ITL with active KV cache and decode context length. This is to provide a more accurate estimation of the performance when ISL and OSL changes and will be used in the sla-planner. The results will be saved to `<output_dir>/<decode/prefill>_tp<best_tp>_interpolation`. Please change the prefill and decode TP size in the config file to match the best TP sizes obtained from the profiling script.
### Prefill Interpolation Data ### Prefill Interpolation Data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment