Unverified Commit 1a5016b0 authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat: add subComponentType in DGD API and uptake in planner (#3200)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 13156361
This diff is collapsed.
...@@ -7,8 +7,6 @@ metadata: ...@@ -7,8 +7,6 @@ metadata:
name: sglang-disagg-planner name: sglang-disagg-planner
spec: spec:
envs: envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:9090"]}]},{"job_name":"frontend","static_configs":[{"targets":["sglang-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE - name: DYNAMO_NAMESPACE
value: "dynamo" value: "dynamo"
services: services:
...@@ -61,45 +59,11 @@ spec: ...@@ -61,45 +59,11 @@ spec:
--backend=sglang --backend=sglang
--adjustment-interval=60 --adjustment-interval=60
--profile-results-dir=/data/profiling_results --profile-results-dir=/data/profiling_results
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: dynamo
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: my-registry/sglang-runtime:my-tag
workingDir: /workspace/components/backends/sglang
command:
- /bin/sh
- -c
args:
- "python3 -m dynamo.planner.prometheus"
decode: decode:
dynamoNamespace: dynamo dynamoNamespace: dynamo
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: decode
replicas: 2 replicas: 2
resources: resources:
limits: limits:
...@@ -131,6 +95,7 @@ spec: ...@@ -131,6 +95,7 @@ spec:
dynamoNamespace: dynamo dynamoNamespace: dynamo
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: prefill
replicas: 2 replicas: 2
resources: resources:
limits: limits:
......
...@@ -7,8 +7,6 @@ metadata: ...@@ -7,8 +7,6 @@ metadata:
name: trtllm-disagg-planner name: trtllm-disagg-planner
spec: spec:
envs: envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:8000"]}]},{"job_name":"frontend","static_configs":[{"targets":["trtllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE - name: DYNAMO_NAMESPACE
value: "trtllm-disagg-planner" value: "trtllm-disagg-planner"
services: services:
...@@ -41,9 +39,6 @@ spec: ...@@ -41,9 +39,6 @@ spec:
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: planner componentType: planner
replicas: 1 replicas: 1
envs:
- name: PROMETHEUS_PORT
value: "8000"
livenessProbe: livenessProbe:
exec: exec:
command: command:
...@@ -84,47 +79,11 @@ spec: ...@@ -84,47 +79,11 @@ spec:
- --adjustment-interval=60 - --adjustment-interval=60
- --profile-results-dir=/data/profiling_results - --profile-results-dir=/data/profiling_results
- --prometheus-port=9085 - --prometheus-port=9085
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: trtllm-disagg-planner
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
- name: PROMETHEUS_PORT
value: "8000"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: my-registry/trtllm-runtime:my-tag
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.planner.prometheus
TRTLLMDecodeWorker: TRTLLMDecodeWorker:
dynamoNamespace: trtllm-disagg-planner dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: decode
replicas: 1 replicas: 1
livenessProbe: livenessProbe:
httpGet: httpGet:
...@@ -173,6 +132,7 @@ spec: ...@@ -173,6 +132,7 @@ spec:
dynamoNamespace: trtllm-disagg-planner dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: prefill
replicas: 1 replicas: 1
resources: resources:
limits: limits:
......
...@@ -7,12 +7,8 @@ metadata: ...@@ -7,12 +7,8 @@ metadata:
name: vllm-disagg-planner name: vllm-disagg-planner
spec: spec:
envs: envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:9090"]}]},{"job_name":"frontend","static_configs":[{"targets":["vllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE - name: DYNAMO_NAMESPACE
value: "vllm-disagg-planner" value: "vllm-disagg-planner"
- name: PROMETHEUS_PORT
value: "8000"
services: services:
Frontend: Frontend:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
...@@ -63,45 +59,11 @@ spec: ...@@ -63,45 +59,11 @@ spec:
--backend=vllm --backend=vllm
--adjustment-interval=60 --adjustment-interval=60
--profile-results-dir=/data/profiling_results --profile-results-dir=/data/profiling_results
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: vllm-disagg-planner
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/components/backends/vllm
command:
- /bin/sh
- -c
args:
- "python3 -m dynamo.planner.prometheus"
VllmDecodeWorker: VllmDecodeWorker:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: decode
replicas: 2 replicas: 2
resources: resources:
limits: limits:
...@@ -127,6 +89,7 @@ spec: ...@@ -127,6 +89,7 @@ spec:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: prefill
replicas: 2 replicas: 2
resources: resources:
limits: limits:
......
...@@ -8,11 +8,17 @@ __all__ = [ ...@@ -8,11 +8,17 @@ __all__ = [
"LoadPlannerDefaults", "LoadPlannerDefaults",
"SLAPlannerDefaults", "SLAPlannerDefaults",
"ServiceConfig", "ServiceConfig",
"TargetReplica",
"SubComponentType",
] ]
# Import the classes # Import the classes
from dynamo.planner.config import ServiceConfig from dynamo.planner.config import ServiceConfig
from dynamo.planner.defaults import LoadPlannerDefaults, SLAPlannerDefaults from dynamo.planner.defaults import (
from dynamo.planner.kubernetes_connector import KubernetesConnector LoadPlannerDefaults,
SLAPlannerDefaults,
SubComponentType,
)
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.virtual_connector import VirtualConnector from dynamo.planner.virtual_connector import VirtualConnector
......
...@@ -15,8 +15,16 @@ ...@@ -15,8 +15,16 @@
import logging import logging
import os import os
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from dynamo.planner.kube import get_current_k8s_namespace from dynamo.planner.kube import get_current_k8s_namespace
from dynamo.planner.utils.exceptions import (
DuplicateSubComponentError,
SubComponentNotFoundError,
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -56,6 +64,10 @@ class LoadPlannerDefaults(BasePlannerDefaults): ...@@ -56,6 +64,10 @@ class LoadPlannerDefaults(BasePlannerDefaults):
def _get_default_prometheus_endpoint(port: str, namespace: str): def _get_default_prometheus_endpoint(port: str, namespace: str):
"""Compute default prometheus endpoint using environment variables and Kubernetes service discovery""" """Compute default prometheus endpoint using environment variables and Kubernetes service discovery"""
prometheus_endpoint = os.environ.get("PROMETHEUS_ENDPOINT", "").strip()
if prometheus_endpoint:
logger.debug("Using PROMETHEUS_ENDPOINT override: %s", prometheus_endpoint)
return prometheus_endpoint
k8s_namespace = get_current_k8s_namespace() k8s_namespace = get_current_k8s_namespace()
if k8s_namespace and k8s_namespace != "default": if k8s_namespace and k8s_namespace != "default":
...@@ -124,3 +136,67 @@ WORKER_COMPONENT_NAMES = { ...@@ -124,3 +136,67 @@ WORKER_COMPONENT_NAMES = {
"sglang": SGLangComponentName, "sglang": SGLangComponentName,
"trtllm": TrtllmComponentName, "trtllm": TrtllmComponentName,
} }
class SubComponentType(str, Enum):
PREFILL = "prefill"
DECODE = "decode"
class Service(BaseModel):
name: str
service: dict
def number_replicas(self) -> int:
return self.service.get("replicas", 0)
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
def get_service_from_sub_component_type_or_name(
deployment: dict,
sub_component_type: SubComponentType,
component_name: Optional[str] = None,
) -> Service:
"""
Get the current replicas for a component in a graph deployment
Returns: Service object
Raises:
SubComponentNotFoundError: If no service with the specified subComponentType is found
DuplicateSubComponentError: If multiple services with the same subComponentType are found
"""
services = deployment.get("spec", {}).get("services", {})
# Collect all available subComponentTypes for better error messages
available_types = []
matching_services = []
for curr_name, curr_service in services.items():
service_sub_type = curr_service.get("subComponentType", "")
if service_sub_type:
available_types.append(service_sub_type)
if service_sub_type == sub_component_type.value:
matching_services.append((curr_name, curr_service))
# Check for duplicates
if len(matching_services) > 1:
service_names = [name for name, _ in matching_services]
raise DuplicateSubComponentError(sub_component_type.value, service_names)
# If no service found with subCompontType and fallback component_name is not provided or not found,
# or if the fallback component has a non-empty subComponentType, raise error
if not matching_services and (
not component_name
or component_name not in services
or services[component_name].get("subComponentType", "") != ""
):
raise SubComponentNotFoundError(sub_component_type.value)
# If fallback component_name is provided and exists within services, add to matching_services
elif not matching_services and component_name in services:
matching_services.append((component_name, services[component_name]))
name, service = matching_services[0]
return Service(name=name, service=service)
...@@ -14,12 +14,18 @@ ...@@ -14,12 +14,18 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import os import logging
from typing import Optional from typing import Optional
from kubernetes import client, config from kubernetes import client, config
from kubernetes.config.config_exception import ConfigException from kubernetes.config.config_exception import ConfigException
from dynamo.planner.utils.exceptions import DynamoGraphDeploymentNotFoundError
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
def get_current_k8s_namespace() -> str: def get_current_k8s_namespace() -> str:
"""Get the current namespace if running inside a k8s cluster""" """Get the current namespace if running inside a k8s cluster"""
...@@ -42,9 +48,7 @@ class KubernetesAPI: ...@@ -42,9 +48,7 @@ class KubernetesAPI:
self.custom_api = client.CustomObjectsApi() self.custom_api = client.CustomObjectsApi()
self.current_namespace = k8s_namespace or get_current_k8s_namespace() self.current_namespace = k8s_namespace or get_current_k8s_namespace()
def _get_graph_deployment_from_name( def _get_graph_deployment_from_name(self, graph_deployment_name: str) -> dict:
self, graph_deployment_name: str
) -> Optional[dict]:
"""Get the graph deployment from the dynamo graph deployment name""" """Get the graph deployment from the dynamo graph deployment name"""
return self.custom_api.get_namespaced_custom_object( return self.custom_api.get_namespaced_custom_object(
group="nvidia.com", group="nvidia.com",
...@@ -54,38 +58,27 @@ class KubernetesAPI: ...@@ -54,38 +58,27 @@ class KubernetesAPI:
name=graph_deployment_name, name=graph_deployment_name,
) )
async def get_parent_graph_deployment(self) -> Optional[dict]: def get_graph_deployment(self, graph_deployment_name: str) -> dict:
""" """
Get the parent DynamoGraphDeployment using environment variable. Get the parent DynamoGraphDeployment
Uses DYN_PARENT_DGD_K8S_NAME environment variable and assumes the DGD
is in the same namespace as this component (self.current_namespace).
Returns: Returns:
The DynamoGraphDeployment object or None if env var is not set The DynamoGraphDeployment object
"""
dgd_name = os.getenv("DYN_PARENT_DGD_K8S_NAME")
if not dgd_name:
return None
Raises:
DynamoGraphDeploymentNotFoundError: If the parent graph deployment is not found
"""
try: try:
return self._get_graph_deployment_from_name(dgd_name) return self._get_graph_deployment_from_name(graph_deployment_name)
except client.ApiException as e: except client.ApiException as e:
if e.status == 404: if e.status == 404:
return None raise DynamoGraphDeploymentNotFoundError(
deployment_name=graph_deployment_name,
namespace=self.current_namespace,
)
raise raise
async def get_graph_deployment(self) -> Optional[dict]: def update_graph_replicas(
"""
Get the parent DynamoGraphDeployment using environment variable.
Returns:
The DynamoGraphDeployment object or None if env var is not set
"""
return await self.get_parent_graph_deployment()
async def update_graph_replicas(
self, graph_deployment_name: str, component_name: str, replicas: int self, graph_deployment_name: str, component_name: str, replicas: int
) -> None: ) -> None:
"""Update the replicas count for a component in a DynamoGraphDeployment""" """Update the replicas count for a component in a DynamoGraphDeployment"""
...@@ -99,15 +92,10 @@ class KubernetesAPI: ...@@ -99,15 +92,10 @@ class KubernetesAPI:
body=patch, body=patch,
) )
async def is_deployment_ready(self, graph_deployment_name: str) -> bool: def is_deployment_ready(self, deployment: dict) -> bool:
"""Check if a graph deployment is ready""" """Check if a graph deployment is ready"""
graph_deployment = self._get_graph_deployment_from_name(graph_deployment_name) conditions = deployment.get("status", {}).get("conditions", [])
if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found")
conditions = graph_deployment.get("status", {}).get("conditions", [])
ready_condition = next( ready_condition = next(
(c for c in conditions if c.get("type") == "Ready"), None (c for c in conditions if c.get("type") == "Ready"), None
) )
...@@ -125,12 +113,7 @@ class KubernetesAPI: ...@@ -125,12 +113,7 @@ class KubernetesAPI:
for attempt in range(max_attempts): for attempt in range(max_attempts):
await asyncio.sleep(delay_seconds) await asyncio.sleep(delay_seconds)
graph_deployment = self._get_graph_deployment_from_name( graph_deployment = self.get_graph_deployment(graph_deployment_name)
graph_deployment_name
)
if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found")
conditions = graph_deployment.get("status", {}).get("conditions", []) conditions = graph_deployment.get("status", {}).get("conditions", [])
ready_condition = next( ready_condition = next(
...@@ -140,7 +123,7 @@ class KubernetesAPI: ...@@ -140,7 +123,7 @@ class KubernetesAPI:
if ready_condition and ready_condition.get("status") == "True": if ready_condition and ready_condition.get("status") == "True":
return # Deployment is ready return # Deployment is ready
print( logger.info(
f"[Attempt {attempt + 1}/{max_attempts}] " f"[Attempt {attempt + 1}/{max_attempts}] "
f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, " f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, "
f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})" f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})"
......
...@@ -14,104 +14,291 @@ ...@@ -14,104 +14,291 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os
import shlex
from typing import Optional from typing import Optional
from pydantic import BaseModel
from dynamo.planner.defaults import (
SubComponentType,
get_service_from_sub_component_type_or_name,
)
from dynamo.planner.kube import KubernetesAPI from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.utils.exceptions import (
DeploymentModelNameMismatchError,
DeploymentValidationError,
EmptyTargetReplicasError,
ModelNameNotFoundError,
PlannerError,
UserProvidedModelNameMismatchError,
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Service(BaseModel):
name: str
service: dict
def number_replicas(self) -> int:
return self.service.get("replicas", 0)
def get_model_name(self) -> Optional[str]:
args = (
self.service.get("extraPodSpec", {})
.get("mainContainer", {})
.get("args", [])
)
args = break_arguments(args)
if (
"--served-model-name" in args
and len(args) > args.index("--served-model-name") + 1
):
return args[args.index("--served-model-name") + 1]
if "--model" in args and len(args) > args.index("--model") + 1:
return args[args.index("--model") + 1]
return None
def break_arguments(args: list[str] | None) -> list[str]:
ans: list[str] = []
if args is None:
return ans
if isinstance(args, str):
# Use shlex.split to properly handle quoted arguments and JSON values
ans = shlex.split(args)
else:
for arg in args:
if arg is not None:
# Use shlex.split to properly handle quoted arguments
ans.extend(shlex.split(arg))
return ans
class TargetReplica(BaseModel):
sub_component_type: SubComponentType
component_name: Optional[str] = None
desired_replicas: int
class KubernetesConnector(PlannerConnector): class KubernetesConnector(PlannerConnector):
def __init__(self, dynamo_namespace: str, k8s_namespace: Optional[str] = None): def __init__(
self,
dynamo_namespace: str,
model_name: Optional[str] = None,
k8s_namespace: Optional[str] = None,
):
self.kube_api = KubernetesAPI(k8s_namespace) self.kube_api = KubernetesAPI(k8s_namespace)
self.dynamo_namespace = dynamo_namespace
async def add_component(self, component_name: str, blocking: bool = True): self.user_provided_model_name: Optional[str] = None
if model_name:
self.user_provided_model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
graph_deployment_name = os.getenv("DYN_PARENT_DGD_K8S_NAME")
if not graph_deployment_name:
raise DeploymentValidationError(
["DYN_PARENT_DGD_K8S_NAME environment variable is not set"]
)
self.graph_deployment_name = graph_deployment_name
async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Add a component by increasing its replica count by 1""" """Add a component by increasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment() deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found")
# get current replicas or 1 if not found service = get_service_from_sub_component_type_or_name(
current_replicas = self._get_current_replicas(deployment, component_name) deployment, sub_component_type
await self.kube_api.update_graph_replicas( )
self._get_graph_deployment_name(deployment), self.kube_api.update_graph_replicas(
component_name, self.graph_deployment_name,
current_replicas + 1, service.name,
service.number_replicas() + 1,
) )
if blocking: if blocking:
await self.kube_api.wait_for_graph_deployment_ready( await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment) self.graph_deployment_name,
) )
async def remove_component(self, component_name: str, blocking: bool = True): async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Remove a component by decreasing its replica count by 1""" """Remove a component by decreasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment() deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found") service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
# get current replicas or 1 if not found )
current_replicas = self._get_current_replicas(deployment, component_name) if service.number_replicas() > 0:
if current_replicas > 0: self.kube_api.update_graph_replicas(
await self.kube_api.update_graph_replicas( self.graph_deployment_name,
self._get_graph_deployment_name(deployment), service.name,
component_name, service.number_replicas() - 1,
current_replicas - 1,
) )
if blocking: if blocking:
await self.kube_api.wait_for_graph_deployment_ready( await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment) self.graph_deployment_name,
) )
async def validate_deployment(
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
):
"""
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
Will fallback to worker service names for backwards compatibility. (TODO: deprecate)
Raises:
DynamoGraphDeploymentNotFoundError: If the deployment is not found
DeploymentValidationError: If the deployment does not contain services with subComponentType prefill and decode
"""
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
errors = []
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
except PlannerError as e:
errors.append(str(e))
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
)
except PlannerError as e:
errors.append(str(e))
try:
self.get_model_name(deployment)
except PlannerError as e:
errors.append(str(e))
# Raise combined error if any issues found
if errors:
raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str:
"""Get the model name from the deployment"""
try:
if deployment is None:
deployment = self.kube_api.get_graph_deployment(
self.graph_deployment_name
)
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate
prefill_service = self.get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
decode_service = self.get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
prefill_model_name = prefill_service.get_model_name()
decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None:
raise ModelNameNotFoundError()
# Check model name between prefill and decode
if prefill_model_name is None:
model_name = decode_model_name
elif decode_model_name is None:
model_name = prefill_model_name
elif prefill_model_name != decode_model_name:
raise DeploymentModelNameMismatchError(
prefill_model_name, decode_model_name
)
else:
model_name = prefill_model_name
except PlannerError as e:
if self.user_provided_model_name:
logger.warning(
f"Failed to get model name from deployment with error: {e}, using provided model name: {self.user_provided_model_name}"
)
model_name = self.user_provided_model_name
else:
raise e
# If user provided a model name and it doesn't match the model name from the deployment, raise an error
if self.user_provided_model_name:
if model_name != self.user_provided_model_name:
raise UserProvidedModelNameMismatchError(
model_name, self.user_provided_model_name
)
if not model_name:
raise ModelNameNotFoundError()
return model_name
async def wait_for_deployment_ready(self):
"""Wait for the deployment to be ready"""
await self.kube_api.wait_for_graph_deployment_ready(
self.graph_deployment_name,
)
async def set_component_replicas( async def set_component_replicas(
self, target_replicas: dict[str, int], blocking: bool = True self, target_replicas: list[TargetReplica], blocking: bool = True
): ):
"""Set the replicas for multiple components at once""" """Set the replicas for multiple components at once"""
if not target_replicas: if not target_replicas:
raise ValueError("target_replicas cannot be empty") raise EmptyTargetReplicasError()
deployment = await self.kube_api.get_graph_deployment() deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found")
if not await self.kube_api.is_deployment_ready( if not self.kube_api.is_deployment_ready(deployment):
self._get_graph_deployment_name(deployment)
):
logger.warning( logger.warning(
f"Deployment {self._get_graph_deployment_name(deployment)} is not ready, ignoring this scaling" f"Deployment {self.graph_deployment_name} is not ready, ignoring this scaling"
) )
return return
for component_name, replicas in target_replicas.items(): for target_replica in target_replicas:
await self.kube_api.update_graph_replicas( service = get_service_from_sub_component_type_or_name(
self._get_graph_deployment_name(deployment), deployment,
component_name, target_replica.sub_component_type,
replicas, component_name=target_replica.component_name,
) )
current_replicas = service.number_replicas()
if current_replicas != target_replica.desired_replicas:
logger.info(
f"Updating {target_replica.sub_component_type.value} component {service.name} to desired replica count {target_replica.desired_replicas}"
)
self.kube_api.update_graph_replicas(
self.graph_deployment_name,
service.name,
target_replica.desired_replicas,
)
else:
logger.info(
f"{target_replica.sub_component_type.value} component {service.name} already at desired replica count {target_replica.desired_replicas}, skipping"
)
if blocking: if blocking:
await self.kube_api.wait_for_graph_deployment_ready( await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment) self.graph_deployment_name,
) )
def _get_current_replicas(self, deployment: dict, component_name: str) -> int:
"""Get the current replicas for a component in a graph deployment"""
return (
deployment.get("spec", {})
.get("services", {})
.get(component_name, {})
.get("replicas", 1)
)
def _get_graph_deployment_name(self, deployment: dict) -> str:
"""Get the name of the graph deployment"""
return deployment["metadata"]["name"]
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
...@@ -121,13 +308,21 @@ if __name__ == "__main__": ...@@ -121,13 +308,21 @@ if __name__ == "__main__":
parser.add_argument("--dynamo_namespace", type=str, default="dynamo") parser.add_argument("--dynamo_namespace", type=str, default="dynamo")
parser.add_argument("--k8s_namespace", type=str, default="default") parser.add_argument("--k8s_namespace", type=str, default="default")
parser.add_argument("--action", type=str, choices=["add", "remove"]) parser.add_argument("--action", type=str, choices=["add", "remove"])
parser.add_argument("--component", type=str, default="planner") parser.add_argument(
"--component",
type=str,
choices=[t.value for t in SubComponentType],
default=SubComponentType.PREFILL.value,
help="Target sub-component to scale",
)
parser.add_argument("--blocking", action="store_true") parser.add_argument("--blocking", action="store_true")
args = parser.parse_args() args = parser.parse_args()
connector = KubernetesConnector(args.dynamo_namespace, args.k8s_namespace) connector = KubernetesConnector(args.dynamo_namespace, args.k8s_namespace)
if args.action == "add": if args.action == "add":
task = connector.add_component(args.component, args.blocking) task = connector.add_component(SubComponentType(args.component), args.blocking)
elif args.action == "remove": elif args.action == "remove":
task = connector.remove_component(args.component, args.blocking) task = connector.remove_component(
SubComponentType(args.component), args.blocking
)
asyncio.run(task) asyncio.run(task)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import subprocess
import tempfile
import yaml
from dynamo.planner.config import ServiceConfig
from dynamo.planner.defaults import SLAPlannerDefaults
from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__)
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
"""Initialize and run Prometheus server with Dynamo config."""
config = ServiceConfig.get_parsed_config("Prometheus")
logger.info(f"Prometheus config: {config}")
await start_prometheus_server(config)
async def start_prometheus_server(config):
logger.info("Starting prometheus server...")
temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
yaml.dump(config, temp_file)
temp_file.close()
config_path = temp_file.name
prometheus_port = SLAPlannerDefaults.port
cmd = [
"prometheus",
f"--config.file={config_path}",
f"--web.listen-address=0.0.0.0:{prometheus_port}",
]
logger.info(f"Prometheus cmd: {cmd}")
process = subprocess.Popen(
cmd,
stdout=None,
stderr=None,
)
# Keep the worker running
try:
while True:
await asyncio.sleep(1)
if process.poll() is not None:
logger.error("Prometheus process died")
break
except asyncio.CancelledError:
logger.info("Shutting down Prometheus...")
process.terminate()
process.wait()
raise
if __name__ == "__main__":
# The dynamo_worker decorator handles runtime setup
import asyncio
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom exceptions for the dynamo planner module.
This module defines a hierarchy of custom exceptions that provide more specific
error handling than generic ValueError exceptions. Each exception includes
contextual information to help with debugging and error handling.
"""
from typing import List
class PlannerError(Exception):
"""Base exception for all planner-related errors.
This serves as the root exception class for all custom exceptions
in the planner module, allowing for broad exception catching when needed.
"""
pass
class DynamoGraphDeploymentNotFoundError(PlannerError):
"""Raised when Parent DynamoGraphDeployment cannot be found.
This typically occurs when:
- The DYN_PARENT_DGD_K8S_NAME environment variable is not set
- The referenced DynamoGraphDeployment doesn't exist in the namespace
"""
def __init__(self, deployment_name: str, namespace: str):
self.deployment_name = deployment_name
self.namespace = namespace
message = f"Parent DynamoGraphDeployment not found (name: '{deployment_name}' in namespace '{namespace}')"
super().__init__(message)
class ComponentError(PlannerError):
"""Base class for subComponent configuration issues.
This serves as a parent class for all exceptions related to
subComponentType configuration problems in DynamoGraphDeployments.
"""
pass
class ModelNameNotFoundError(PlannerError):
"""Raised when the model name is not found in the deployment"""
def __init__(self):
message = "Model name not found in DynamoGraphDeployment"
super().__init__(message)
class DeploymentModelNameMismatchError(PlannerError):
"""Raised when the model name is not the same in the deployment"""
def __init__(self, prefill_model_name: str, decode_model_name: str):
self.prefill_model_name = prefill_model_name
self.decode_model_name = decode_model_name
message = f"Model name mismatch in DynamoGraphDeployment: prefill model name {prefill_model_name} != decode model name {decode_model_name}"
self.message = message
super().__init__(self.message)
class UserProvidedModelNameMismatchError(PlannerError):
"""Raised when the model name is not the same as the user provided model name"""
def __init__(self, model_name: str, user_provided_model_name: str):
self.model_name = model_name
self.user_provided_model_name = user_provided_model_name
message = f"Model name {model_name} does not match expected model name {user_provided_model_name}"
self.message = message
super().__init__(self.message)
class BackendFrameworkNotFoundError(PlannerError):
"""Raised when the backend framework is not supported.
This occurs when the DynamoGraphDeployment contains an unsupported backend framework.
"""
def __init__(self):
message = "Backend framework not found on DynamoGraphDeployment"
super().__init__(message)
class BackendFrameworkInvalidError(PlannerError):
"""Raised when the backend framework does not exist.
This occurs when the DynamoGraphDeployment contains an unsupported backend framework.
"""
def __init__(self, backend_framework: str):
self.backend_framework = backend_framework
message = f"Backend framework {backend_framework} is invalid"
super().__init__(message)
class SubComponentNotFoundError(ComponentError):
"""Raised when a required subComponentType is not found in the deployment.
This occurs when the DynamoGraphDeployment doesn't contain any service
with the requested subComponentType (e.g., 'prefill', 'decode').
"""
def __init__(self, sub_component_type: str):
self.sub_component_type = sub_component_type
message = f"DynamoGraphDeployment must contain a service with subComponentType '{sub_component_type}'"
super().__init__(message)
class DuplicateSubComponentError(ComponentError):
"""Raised when multiple services have the same subComponentType.
This occurs when the DynamoGraphDeployment contains more than one service
with the same subComponentType, which violates the expected uniqueness constraint.
"""
def __init__(self, sub_component_type: str, service_names: List[str]):
self.sub_component_type = sub_component_type
self.service_names = service_names
message = (
f"DynamoGraphDeployment must contain only one service with "
f"subComponentType '{sub_component_type}', but found multiple: "
f"{', '.join(sorted(service_names))}"
)
super().__init__(message)
class DeploymentValidationError(PlannerError):
"""Raised when deployment validation fails for multiple components.
This is used to aggregate multiple validation errors into a single exception,
providing a comprehensive view of all validation issues.
"""
def __init__(self, errors: List[str]):
self.errors = errors
message = f"Service verification failed: {'; '.join(errors)}"
super().__init__(message)
class EmptyTargetReplicasError(PlannerError):
"""Raised when target_replicas is empty or invalid.
This occurs when attempting to set component replicas with an empty
or invalid target_replicas dictionary.
"""
def __init__(
self,
):
message = "target_replicas cannot be empty"
super().__init__(message)
...@@ -118,4 +118,9 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -118,4 +118,9 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
default=SLAPlannerDefaults.no_correction, default=SLAPlannerDefaults.no_correction,
help="Disable correction factor", help="Disable correction factor",
) )
parser.add_argument(
"--model-name",
type=str,
help="Model name of deployment (only required for virtual environment)",
)
return parser return parser
...@@ -11,7 +11,12 @@ from typing import Optional ...@@ -11,7 +11,12 @@ from typing import Optional
from prometheus_client import Gauge, start_http_server from prometheus_client import Gauge, start_http_server
from dynamo.planner import KubernetesConnector, VirtualConnector from dynamo.planner import (
KubernetesConnector,
SubComponentType,
TargetReplica,
VirtualConnector,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import ( from dynamo.planner.utils.perf_interpolation import (
...@@ -63,22 +68,30 @@ class Planner: ...@@ -63,22 +68,30 @@ class Planner:
self.args = args self.args = args
self.dryrun = dryrun self.dryrun = dryrun
# Rely on getting model name from connector
self.model_name: Optional[str] = None
if not self.dryrun: if not self.dryrun:
self.runtime = runtime self.runtime = runtime
self.namespace = args.namespace self.namespace = args.namespace
if not args.no_operation: if not args.no_operation:
if args.environment == "kubernetes": if args.environment == "kubernetes":
self.connector = KubernetesConnector(self.namespace) self.connector = KubernetesConnector(
self.namespace, self.model_name
)
elif args.environment == "virtual": elif args.environment == "virtual":
self.connector = VirtualConnector( self.connector = VirtualConnector(
runtime, self.namespace, args.backend runtime,
self.namespace,
args.model_name,
) )
else: else:
raise ValueError(f"Invalid environment: {args.environment}") raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient( self.prometheus_api_client = PrometheusAPIClient(
SLAPlannerDefaults.prometheus_endpoint SLAPlannerDefaults.prometheus_endpoint,
args.namespace,
) )
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor]( self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor](
...@@ -121,6 +134,13 @@ class Planner: ...@@ -121,6 +134,13 @@ class Planner:
self.prefill_interpolator = PrefillInterpolator(args.profile_results_dir) self.prefill_interpolator = PrefillInterpolator(args.profile_results_dir)
self.decode_interpolator = DecodeInterpolator(args.profile_results_dir) self.decode_interpolator = DecodeInterpolator(args.profile_results_dir)
self.prefill_component_name = WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_k8s_name
self.decode_component_name = WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_k8s_name
if not self.dryrun: if not self.dryrun:
self.prefill_client = None self.prefill_client = None
self.workers_client = None self.workers_client = None
...@@ -230,27 +250,33 @@ class Planner: ...@@ -230,27 +250,33 @@ class Planner:
self.num_d_workers_gauge.set(len(self.d_endpoints)) self.num_d_workers_gauge.set(len(self.d_endpoints))
self.last_metrics.ttft = self.prometheus_api_client.get_avg_time_to_first_token( self.last_metrics.ttft = self.prometheus_api_client.get_avg_time_to_first_token(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
self.last_metrics.itl = self.prometheus_api_client.get_avg_inter_token_latency( self.last_metrics.itl = self.prometheus_api_client.get_avg_inter_token_latency(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
self.last_metrics.num_req = self.prometheus_api_client.get_avg_request_count( self.last_metrics.num_req = self.prometheus_api_client.get_avg_request_count(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
self.last_metrics.request_duration = ( self.last_metrics.request_duration = (
self.prometheus_api_client.get_avg_request_duration( self.prometheus_api_client.get_avg_request_duration(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
) )
self.last_metrics.isl = ( self.last_metrics.isl = (
self.prometheus_api_client.get_avg_input_sequence_tokens( self.prometheus_api_client.get_avg_input_sequence_tokens(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
) )
self.last_metrics.osl = ( self.last_metrics.osl = (
self.prometheus_api_client.get_avg_output_sequence_tokens( self.prometheus_api_client.get_avg_output_sequence_tokens(
f"{self.args.adjustment_interval}s" f"{self.args.adjustment_interval}s",
self.model_name,
) )
) )
...@@ -429,19 +455,43 @@ class Planner: ...@@ -429,19 +455,43 @@ class Planner:
return return
if not self.args.no_operation: if not self.args.no_operation:
target_replicas = { target_replicas = [
WORKER_COMPONENT_NAMES[ TargetReplica(
self.args.backend sub_component_type=SubComponentType.PREFILL,
].prefill_worker_k8s_name: next_num_p, component_name=self.prefill_component_name,
WORKER_COMPONENT_NAMES[ desired_replicas=next_num_p,
self.args.backend ),
].decode_worker_k8s_name: next_num_d, TargetReplica(
} sub_component_type=SubComponentType.DECODE,
component_name=self.decode_component_name,
desired_replicas=next_num_d,
),
]
await self.connector.set_component_replicas(target_replicas, blocking=False) await self.connector.set_component_replicas(target_replicas, blocking=False)
async def run(self): async def run(self):
"""Main loop for the planner""" """Main loop for the planner"""
if not self.args.no_operation:
# Fail fast if the deployment is not valid
logger.info("Validating deployment...")
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
await self.connector.validate_deployment(
prefill_component_name=self.prefill_component_name,
decode_component_name=self.decode_component_name,
)
logger.info("Successfully validated the deployment")
await self.connector.wait_for_deployment_ready()
model_name = self.connector.get_model_name()
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
self.last_adjustment_time = time.time() self.last_adjustment_time = time.time()
while True: while True:
...@@ -453,6 +503,7 @@ class Planner: ...@@ -453,6 +503,7 @@ class Planner:
): ):
self.last_adjustment_time = time.time() self.last_adjustment_time = time.time()
logger.info("New adjustment interval started!") logger.info("New adjustment interval started!")
await self.observe_metrics() await self.observe_metrics()
await self.make_adjustments() await self.make_adjustments()
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
import typing
from prometheus_api_client import PrometheusConnect from prometheus_api_client import PrometheusConnect
from pydantic import BaseModel, ValidationError
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -23,12 +25,33 @@ configure_dynamo_logging() ...@@ -23,12 +25,33 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FrontendMetric(BaseModel):
container: typing.Optional[str] = None
dynamo_namespace: typing.Optional[str] = None
endpoint: typing.Optional[str] = None
instance: typing.Optional[str] = None
job: typing.Optional[str] = None
model: typing.Optional[str] = None
namespace: typing.Optional[str] = None
pod: typing.Optional[str] = None
class FrontendMetricContainer(BaseModel):
metric: FrontendMetric
value: typing.Tuple[float, float] # [timestamp, value]
class PrometheusAPIClient: class PrometheusAPIClient:
def __init__(self, url: str): def __init__(self, url: str, dynamo_namespace: str):
self.prom = PrometheusConnect(url=url, disable_ssl=True) self.prom = PrometheusConnect(url=url, disable_ssl=True)
self.dynamo_namespace = dynamo_namespace
def _get_average_metric( def _get_average_metric(
self, metric_name: str, interval: str, operation_name: str self,
metric_name: str,
interval: str,
operation_name: str,
model_name: str,
) -> float: ) -> float:
""" """
Helper method to get average metrics using the pattern: Helper method to get average metrics using the pattern:
...@@ -50,57 +73,92 @@ class PrometheusAPIClient: ...@@ -50,57 +73,92 @@ class PrometheusAPIClient:
if not result: if not result:
# No data available yet (no requests made) - return 0 silently # No data available yet (no requests made) - return 0 silently
return 0 return 0
return float(result[0]["value"][1]) metrics_containers = parse_frontend_metric_containers(result)
values = []
for container in metrics_containers:
if (
container.metric.model == model_name
and container.metric.dynamo_namespace == self.dynamo_namespace
):
values.append(container.value[1])
if not values:
return 0
return sum(values) / len(values)
except Exception as e: except Exception as e:
logger.error(f"Error getting {operation_name}: {e}") logger.error(f"Error getting {operation_name}: {e}")
return 0 return 0
def get_avg_inter_token_latency(self, interval: str): def get_avg_inter_token_latency(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
"inter_token_latency_seconds", "inter_token_latency_seconds",
interval, interval,
"avg inter token latency", "avg inter token latency",
model_name,
) )
def get_avg_time_to_first_token(self, interval: str): def get_avg_time_to_first_token(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
"time_to_first_token_seconds", "time_to_first_token_seconds",
interval, interval,
"avg time to first token", "avg time to first token",
model_name,
) )
def get_avg_request_duration(self, interval: str): def get_avg_request_duration(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
"request_duration_seconds", "request_duration_seconds",
interval, interval,
"avg request duration", "avg request duration",
model_name,
) )
def get_avg_request_count(self, interval: str): def get_avg_request_count(self, interval: str, model_name: str):
# This function follows a different query pattern than the other metrics # This function follows a different query pattern than the other metrics
try: try:
raw_res = self.prom.custom_query( raw_res = self.prom.custom_query(
query=f"increase(dynamo_frontend_requests_total[{interval}])" query=f"increase(dynamo_frontend_requests_total[{interval}])"
) )
metrics_containers = parse_frontend_metric_containers(raw_res)
total_count = 0.0 total_count = 0.0
for res in raw_res: for container in metrics_containers:
# count all success/failed and stream/non-stream requests if (
total_count += float(res["value"][1]) container.metric.model == model_name
and container.metric.dynamo_namespace == self.dynamo_namespace
):
total_count += container.value[1]
return total_count return total_count
except Exception as e: except Exception as e:
logger.error(f"Error getting avg request count: {e}") logger.error(f"Error getting avg request count: {e}")
return 0 return 0
def get_avg_input_sequence_tokens(self, interval: str): def get_avg_input_sequence_tokens(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
"input_sequence_tokens", "input_sequence_tokens",
interval, interval,
"avg input sequence tokens", "avg input sequence tokens",
model_name,
) )
def get_avg_output_sequence_tokens(self, interval: str): def get_avg_output_sequence_tokens(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
"output_sequence_tokens", "output_sequence_tokens",
interval, interval,
"avg output sequence tokens", "avg output sequence tokens",
model_name,
) )
def parse_frontend_metric_containers(
result: list[dict],
) -> list[FrontendMetricContainer]:
metrics_containers: list[FrontendMetricContainer] = []
for res in result:
try:
metrics_containers.append(FrontendMetricContainer.model_validate(res))
except ValidationError as e:
logger.error(f"Error parsing frontend metric container: {e}")
continue
return metrics_containers
...@@ -6,8 +6,9 @@ import os ...@@ -6,8 +6,9 @@ import os
from typing import Optional from typing import Optional
from dynamo._core import VirtualConnectorCoordinator from dynamo._core import VirtualConnectorCoordinator
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -32,7 +33,10 @@ class VirtualConnector(PlannerConnector): ...@@ -32,7 +33,10 @@ class VirtualConnector(PlannerConnector):
""" """
def __init__( def __init__(
self, runtime: DistributedRuntime, dynamo_namespace: str, backend: str self,
runtime: DistributedRuntime,
dynamo_namespace: str,
model_name: Optional[str] = None,
): ):
self.connector = VirtualConnectorCoordinator( self.connector = VirtualConnectorCoordinator(
runtime, runtime,
...@@ -42,8 +46,12 @@ class VirtualConnector(PlannerConnector): ...@@ -42,8 +46,12 @@ class VirtualConnector(PlannerConnector):
SCALING_MAX_RETRIES, SCALING_MAX_RETRIES,
) )
self.backend = backend if model_name is None:
self.worker_component_names = WORKER_COMPONENT_NAMES[backend] raise ValueError("Model name is required for virtual connector")
self.model_name = model_name.lower() # normalize model name to lowercase (MDC)
self.dynamo_namespace = dynamo_namespace
async def _async_init(self): async def _async_init(self):
"""Async initialization that must be called after __init__""" """Async initialization that must be called after __init__"""
...@@ -59,47 +67,32 @@ class VirtualConnector(PlannerConnector): ...@@ -59,47 +67,32 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment environment to report that scaling is complete""" """Wait for the deployment environment to report that scaling is complete"""
await self.connector.wait_for_scaling_completion() await self.connector.wait_for_scaling_completion()
def _component_to_worker_type(self, component_name: str) -> Optional[str]: async def add_component(
"""Map component name to worker type (prefill or decode)""" self, sub_component_type: SubComponentType, blocking: bool = True
if component_name == self.worker_component_names.prefill_worker_k8s_name: ):
return "prefill"
elif component_name == self.worker_component_names.decode_worker_k8s_name:
return "decode"
else:
return None
async def add_component(self, component_name: str, blocking: bool = True):
"""Add a component by increasing its replica count by 1""" """Add a component by increasing its replica count by 1"""
worker_type = self._component_to_worker_type(component_name)
if worker_type is None:
logger.warning(f"Unknown component name: {component_name}, skipping")
return
state = self.connector.read_state() state = self.connector.read_state()
if worker_type == "prefill": if sub_component_type == SubComponentType.PREFILL:
await self._update_scaling_decision( await self._update_scaling_decision(
num_prefill=state.num_prefill_workers + 1 num_prefill=state.num_prefill_workers + 1
) )
elif worker_type == "decode": elif sub_component_type == SubComponentType.DECODE:
await self._update_scaling_decision(num_decode=state.num_decode_workers + 1) await self._update_scaling_decision(num_decode=state.num_decode_workers + 1)
if blocking: if blocking:
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
async def remove_component(self, component_name: str, blocking: bool = True): async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Remove a component by decreasing its replica count by 1""" """Remove a component by decreasing its replica count by 1"""
worker_type = self._component_to_worker_type(component_name)
if worker_type is None:
logger.warning(f"Unknown component name: {component_name}, skipping")
return
state = self.connector.read_state() state = self.connector.read_state()
if worker_type == "prefill": if sub_component_type == SubComponentType.PREFILL:
new_count = max(0, state.num_prefill_workers - 1) new_count = max(0, state.num_prefill_workers - 1)
await self._update_scaling_decision(num_prefill=new_count) await self._update_scaling_decision(num_prefill=new_count)
elif worker_type == "decode": elif sub_component_type == SubComponentType.DECODE:
new_count = max(0, state.num_decode_workers - 1) new_count = max(0, state.num_decode_workers - 1)
await self._update_scaling_decision(num_decode=new_count) await self._update_scaling_decision(num_decode=new_count)
...@@ -107,25 +100,20 @@ class VirtualConnector(PlannerConnector): ...@@ -107,25 +100,20 @@ class VirtualConnector(PlannerConnector):
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
async def set_component_replicas( async def set_component_replicas(
self, target_replicas: dict[str, int], blocking: bool = True self, target_replicas: list[TargetReplica], blocking: bool = True
): ):
"""Set the replicas for multiple components at once""" """Set the replicas for multiple components at once"""
if not target_replicas: if not target_replicas:
raise ValueError("target_replicas cannot be empty") raise EmptyTargetReplicasError()
num_prefill = None num_prefill = None
num_decode = None num_decode = None
for component_name, replicas in target_replicas.items(): for target_replica in target_replicas:
worker_type = self._component_to_worker_type(component_name) if target_replica.sub_component_type == SubComponentType.PREFILL:
if worker_type is None: num_prefill = target_replica.desired_replicas
logger.warning(f"Unknown component name: {component_name}, skipping") elif target_replica.sub_component_type == SubComponentType.DECODE:
continue num_decode = target_replica.desired_replicas
if worker_type == "prefill":
num_prefill = replicas
elif worker_type == "decode":
num_decode = replicas
if num_prefill is None and num_decode is None: if num_prefill is None and num_decode is None:
return return
...@@ -137,3 +125,19 @@ class VirtualConnector(PlannerConnector): ...@@ -137,3 +125,19 @@ class VirtualConnector(PlannerConnector):
if blocking: if blocking:
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
async def validate_deployment(
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
):
"""Validate the deployment"""
pass
async def wait_for_deployment_ready(self):
"""Wait for the deployment to be ready"""
await self._wait_for_scaling_completion()
async def get_model_name(self) -> str:
"""Get the model name from the deployment"""
return self.model_name
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from typing import Any, Dict from typing import Any, Dict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from kubernetes import client
from dynamo.planner.kube import KubernetesAPI from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.utils.exceptions import DynamoGraphDeploymentNotFoundError
@pytest.fixture @pytest.fixture
...@@ -75,6 +76,21 @@ def test_get_graph_deployment_from_name(k8s_api, mock_custom_api): ...@@ -75,6 +76,21 @@ def test_get_graph_deployment_from_name(k8s_api, mock_custom_api):
) )
def test_update_graph_replicas(k8s_api, mock_custom_api):
mock_custom_api.patch_namespaced_custom_object.return_value = None
k8s_api.update_graph_replicas("test-deployment", "test-component", 1)
mock_custom_api.patch_namespaced_custom_object.assert_called_once_with(
group="nvidia.com",
version="v1alpha1",
namespace=k8s_api.current_namespace,
plural="dynamographdeployments",
name="test-deployment",
body={"spec": {"services": {"test-component": {"replicas": 1}}}},
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_deployment_ready_true(k8s_api, mock_custom_api): async def test_is_deployment_ready_true(k8s_api, mock_custom_api):
"""Test is_deployment_ready method when deployment is ready""" """Test is_deployment_ready method when deployment is ready"""
...@@ -87,12 +103,8 @@ async def test_is_deployment_ready_true(k8s_api, mock_custom_api): ...@@ -87,12 +103,8 @@ async def test_is_deployment_ready_true(k8s_api, mock_custom_api):
} }
} }
# Mock the method on the instance result = k8s_api.is_deployment_ready(mock_deployment)
with patch.object( assert result is True
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
result = await k8s_api.is_deployment_ready("test-deployment")
assert result is True
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -109,24 +121,8 @@ async def test_is_deployment_ready_false(k8s_api, mock_custom_api): ...@@ -109,24 +121,8 @@ async def test_is_deployment_ready_false(k8s_api, mock_custom_api):
] ]
} }
} }
result = k8s_api.is_deployment_ready(mock_deployment)
# Mock the method on the instance assert result is False
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
result = await k8s_api.is_deployment_ready("test-deployment")
assert result is False
@pytest.mark.asyncio
async def test_is_deployment_ready_not_found(k8s_api, mock_custom_api):
"""Test is_deployment_ready method when deployment is not found"""
# Mock the method on the instance
with patch.object(k8s_api, "_get_graph_deployment_from_name", return_value=None):
with pytest.raises(ValueError) as exc_info:
await k8s_api.is_deployment_ready("test-deployment")
assert "not found" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -142,9 +138,7 @@ async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api) ...@@ -142,9 +138,7 @@ async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api)
} }
# Mock the method on the instance # Mock the method on the instance
with patch.object( with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
# Test with minimal attempts and delay for faster testing # Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready( await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1 "test-deployment", max_attempts=2, delay_seconds=0.1
...@@ -168,9 +162,7 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api) ...@@ -168,9 +162,7 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api)
} }
# Mock the method on the instance # Mock the method on the instance
with patch.object( with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
# Test with minimal attempts and delay for faster testing # Test with minimal attempts and delay for faster testing
with pytest.raises(TimeoutError) as exc_info: with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready( await k8s_api.wait_for_graph_deployment_ready(
...@@ -183,15 +175,21 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api) ...@@ -183,15 +175,21 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_graph_deployment_not_found(k8s_api, mock_custom_api): async def test_wait_for_graph_deployment_not_found(k8s_api, mock_custom_api):
"""Test wait_for_graph_deployment_ready when deployment is not found""" """Test wait_for_graph_deployment_ready when deployment is not found"""
# Mock the _get_graph_deployment_from_name response to return None
with patch.object(k8s_api, "_get_graph_deployment_from_name", return_value=None):
# Test with minimal attempts and delay for faster testing
with pytest.raises(ValueError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
assert "not found" in str(exc_info.value) mock_custom_api.get_namespaced_custom_object.side_effect = client.ApiException(
status=404
)
# Test with minimal attempts and delay for faster testing
with pytest.raises(DynamoGraphDeploymentNotFoundError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
)
# Validate the exception fields
exception = exc_info.value
assert exception.deployment_name == "test-deployment"
assert exception.namespace == "default"
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -200,9 +198,7 @@ async def test_wait_for_graph_deployment_no_conditions(k8s_api, mock_custom_api) ...@@ -200,9 +198,7 @@ async def test_wait_for_graph_deployment_no_conditions(k8s_api, mock_custom_api)
# Mock the _get_graph_deployment_from_name response with no conditions # Mock the _get_graph_deployment_from_name response with no conditions
mock_deployment: Dict[str, Any] = {"status": {}} mock_deployment: Dict[str, Any] = {"status": {}}
with patch.object( with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
# Test with minimal attempts and delay for faster testing # Test with minimal attempts and delay for faster testing
with pytest.raises(TimeoutError) as exc_info: with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready( await k8s_api.wait_for_graph_deployment_ready(
...@@ -249,37 +245,28 @@ async def test_wait_for_graph_deployment_ready_on_second_attempt( ...@@ -249,37 +245,28 @@ async def test_wait_for_graph_deployment_ready_on_second_attempt(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_parent_graph_deployment_with_env_var(k8s_api, mock_custom_api): async def test_get_graph_deployment(k8s_api, mock_custom_api):
"""Test get_parent_graph_deployment with environment variable set""" """Test get_graph_deployment"""
mock_deployment = {"metadata": {"name": "parent-dgd"}} mock_deployment = {"metadata": {"name": "parent-dgd"}}
with patch.dict(os.environ, {"DYN_PARENT_DGD_K8S_NAME": "parent-dgd"}): with patch.object(
with patch.object( k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment ) as mock_get:
) as mock_get: result = await k8s_api.get_graph_deployment("parent-dgd")
result = await k8s_api.get_parent_graph_deployment()
assert result == mock_deployment
mock_get.assert_called_once_with("parent-dgd")
@pytest.mark.asyncio assert result == mock_deployment
async def test_get_parent_graph_deployment_without_env_var(k8s_api, mock_custom_api): mock_get.assert_called_once_with("parent-dgd")
"""Test get_parent_graph_deployment without environment variable"""
with patch.dict(os.environ, {}, clear=True):
result = await k8s_api.get_parent_graph_deployment()
assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_deployment_delegates_to_parent(k8s_api, mock_custom_api): async def test_get_graph_deployment_not_found(k8s_api, mock_custom_api):
"""Test get_graph_deployment delegates to get_parent_graph_deployment""" """Test get_graph_deployment when deployment is not found"""
mock_deployment = {"metadata": {"name": "parent-dgd"}} k8s_api.custom_api.get_namespaced_custom_object.side_effect = client.ApiException(
status=404
with patch.object( )
k8s_api, "get_parent_graph_deployment", return_value=mock_deployment with pytest.raises(DynamoGraphDeploymentNotFoundError) as exc_info:
) as mock_parent: await k8s_api.get_graph_deployment("parent-dgd")
result = await k8s_api.get_graph_deployment()
assert result == mock_deployment exception = exc_info.value
mock_parent.assert_called_once() assert exception.deployment_name == "parent-dgd"
assert exception.namespace == "default"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from unittest.mock import patch
import pytest
from dynamo.planner.utils.prometheus import (
FrontendMetric,
FrontendMetricContainer,
PrometheusAPIClient,
)
@pytest.fixture
def mock_prometheus_result():
"""Fixture providing mock prometheus result data for testing"""
return [
{
"metric": {
"container": "main",
"dynamo_namespace": "different_namespace",
"model": "different_model",
"namespace": "dynamo-system",
},
"value": [1758857776.071, 10.5],
},
{
"metric": {
"container": "main",
"dynamo_namespace": "target_namespace",
"model": "target_model",
"namespace": "dynamo-system",
},
"value": [1758857776.071, 42.7],
},
{
"metric": {
"container": "worker",
"dynamo_namespace": "target_namespace",
"model": "target_model",
"namespace": "dynamo-system",
},
"value": [1758857776.071, 35.5],
},
{
"metric": {
"container": "sidecar",
"dynamo_namespace": "target_namespace",
"model": "target_model",
"namespace": "dynamo-system",
},
"value": [30.0, 15.5],
},
]
def test_frontend_metric_container_with_nan_value():
test_data = {
"metric": {
"container": "main",
"dynamo_namespace": "vllm-disagg-planner",
"endpoint": "http",
"instance": "10.244.2.163:8000",
"job": "dynamo-system/dynamo-frontend",
"model": "qwen/qwen3-0.6b",
"namespace": "dynamo-system",
"pod": "vllm-disagg-planner-frontend-865f84c49-6q7s5",
},
"value": [1758857776.071, "NaN"],
}
container = FrontendMetricContainer.model_validate(test_data)
assert container.metric.container == "main"
assert container.metric.dynamo_namespace == "vllm-disagg-planner"
assert container.metric.endpoint == "http"
assert container.metric.instance == "10.244.2.163:8000"
assert container.metric.job == "dynamo-system/dynamo-frontend"
assert container.metric.model == "qwen/qwen3-0.6b"
assert container.metric.namespace == "dynamo-system"
assert container.metric.pod == "vllm-disagg-planner-frontend-865f84c49-6q7s5"
assert container.value[0] == 1758857776.071
assert math.isnan(
container.value[1]
) # becomes special float value that can't be asserted to itself
test_data["value"][1] = 42.5 # type: ignore[index]
container = FrontendMetricContainer.model_validate(test_data)
assert container.value[1] == 42.5
def test_frontend_metric_with_partial_data():
"""Test FrontendMetric with partial data (optional fields)"""
test_data = {
"container": "main",
"model": "qwen/qwen3-0.6b",
"namespace": "dynamo-system",
}
metric = FrontendMetric.model_validate(test_data)
# Assert provided fields
assert metric.container == "main"
assert metric.model == "qwen/qwen3-0.6b"
assert metric.namespace == "dynamo-system"
# Assert optional fields are None
assert metric.dynamo_namespace is None
assert metric.endpoint is None
assert metric.instance is None
assert metric.job is None
assert metric.pod is None
def test_get_average_metric_none_result():
"""Test _get_average_metric when prometheus returns None"""
client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
with patch.object(client.prom, "custom_query") as mock_query:
mock_query.return_value = None
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="test_model",
)
assert result == 0
def test_get_average_metric_empty_result():
"""Test _get_average_metric when prometheus returns empty list"""
client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
with patch.object(client.prom, "custom_query") as mock_query:
mock_query.return_value = []
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="test_model",
)
assert result == 0
def test_get_average_metric_no_matching_containers(mock_prometheus_result):
"""Test _get_average_metric with valid containers but no matches"""
client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
with patch.object(client.prom, "custom_query") as mock_query:
# Use only the first container which doesn't match target criteria
mock_query.return_value = [mock_prometheus_result[0]]
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="target_model",
)
assert result == 0
def test_get_average_metric_one_matching_container(mock_prometheus_result):
"""Test _get_average_metric with one matching container"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
with patch.object(client.prom, "custom_query") as mock_query:
# Use first two containers - one doesn't match, one does
mock_query.return_value = mock_prometheus_result[:2]
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="target_model",
)
assert result == 42.7
def test_get_average_metric_with_validation_error():
"""Test _get_average_metric with one valid container and one that fails validation"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
mock_result = [
{
"metric": {
"container": "main",
"dynamo_namespace": "target_namespace",
"model": "target_model",
"namespace": "dynamo-system",
},
"value": [1758857776.071, 25.5],
},
{
# Invalid structure - missing required fields that will cause validation error
"invalid_structure": "bad_data",
"value": "not_a_tuple",
},
]
with patch.object(client.prom, "custom_query") as mock_query:
mock_query.return_value = mock_result
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="target_model",
)
assert result == 25.5
def test_get_average_metric_multiple_matching_containers(mock_prometheus_result):
"""Test _get_average_metric with multiple matching containers returns average"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
with patch.object(client.prom, "custom_query") as mock_query:
# Use containers 1, 2, 3 which all match target criteria
mock_query.return_value = mock_prometheus_result[1:]
result = client._get_average_metric(
metric_name="test_metric",
interval="60s",
operation_name="test operation",
model_name="target_model",
)
# Average of 42.7, 35.5, and 15.5 (using value[1] from each container)
expected = (42.7 + 35.5 + 15.5) / 3
assert result == expected
...@@ -11,7 +11,7 @@ import logging ...@@ -11,7 +11,7 @@ import logging
import pytest import pytest
from dynamo._core import DistributedRuntime, VirtualConnectorClient from dynamo._core import DistributedRuntime, VirtualConnectorClient
from dynamo.planner import VirtualConnector from dynamo.planner import SubComponentType, TargetReplica, VirtualConnector
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -49,7 +49,10 @@ def test_main(): ...@@ -49,7 +49,10 @@ def test_main():
async def next_scaling_decision(c): async def next_scaling_decision(c):
"""Move the second decision in to a separate task so we can `.wait` for it.""" """Move the second decision in to a separate task so we can `.wait` for it."""
replicas = {"prefill": 5, "decode": 8} replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=5),
TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=8),
]
await c.set_component_replicas(replicas, blocking=False) await c.set_component_replicas(replicas, blocking=False)
...@@ -57,7 +60,10 @@ async def async_internal(distributed_runtime): ...@@ -57,7 +60,10 @@ async def async_internal(distributed_runtime):
# This is Dynamo Planner # This is Dynamo Planner
c = VirtualConnector(distributed_runtime, NAMESPACE, "sglang") c = VirtualConnector(distributed_runtime, NAMESPACE, "sglang")
await c._async_init() await c._async_init()
replicas = {"prefill": 1, "decode": 2} replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=1),
TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=2),
]
await c.set_component_replicas(replicas, blocking=False) await c.set_component_replicas(replicas, blocking=False)
# This is the client # This is the client
...@@ -86,7 +92,10 @@ async def async_internal(distributed_runtime): ...@@ -86,7 +92,10 @@ async def async_internal(distributed_runtime):
await c._wait_for_scaling_completion() await c._wait_for_scaling_completion()
# Now scale to zero # Now scale to zero
replicas = {"prefill": 0, "decode": 0} replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=0),
TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=0),
]
await c.set_component_replicas(replicas, blocking=False) await c.set_component_replicas(replicas, blocking=False)
event = await client.get() event = await client.get()
assert event.num_prefill_workers == 0 assert event.num_prefill_workers == 0
......
...@@ -156,21 +156,6 @@ RUN apt-get update && \ ...@@ -156,21 +156,6 @@ RUN apt-get update && \
ca-certificates && \ ca-certificates && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Install prometheus
ARG PROM_VERSION=3.4.1
RUN ARCH=$(dpkg --print-architecture) && \
case "$ARCH" in \
amd64) PLATFORM=linux-amd64 ;; \
arm64) PLATFORM=linux-arm64 ;; \
*) echo "Unsupported architecture: $ARCH" && exit 1 ;; \
esac && \
curl -fsSL --retry 5 --retry-delay 5 "https://github.com/prometheus/prometheus/releases/download/v${PROM_VERSION}/prometheus-${PROM_VERSION}.${PLATFORM}.tar.gz" \
| tar -xz -C /tmp && \
mv "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}/prometheus" /usr/local/bin/ && \
chmod +x /usr/local/bin/prometheus && \
rm -rf "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}"
# Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image # Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image
COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc
COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++ COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++
......
...@@ -99,20 +99,6 @@ RUN apt-get update && \ ...@@ -99,20 +99,6 @@ RUN apt-get update && \
jq && \ jq && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Install prometheus
ARG PROM_VERSION=3.4.1
RUN ARCH=$(dpkg --print-architecture) && \
case "$ARCH" in \
amd64) PLATFORM=linux-amd64 ;; \
arm64) PLATFORM=linux-arm64 ;; \
*) echo "Unsupported architecture: $ARCH" && exit 1 ;; \
esac && \
curl -fsSL --retry 5 --retry-delay 5 "https://github.com/prometheus/prometheus/releases/download/v${PROM_VERSION}/prometheus-${PROM_VERSION}.${PLATFORM}.tar.gz" \
| tar -xz -C /tmp && \
mv "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}/prometheus" /usr/local/bin/ && \
chmod +x /usr/local/bin/prometheus && \
rm -rf "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}"
# Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image # Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image
COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc
COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++ COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment