"examples/multimodal/components/web.py" did not exist on "16310b269f866e6f4b7968ba6780e54a4f7b76f6"
Unverified Commit 1a5016b0 authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat: add subComponentType in DGD API and uptake in planner (#3200)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 13156361
......@@ -26,7 +26,7 @@ from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......@@ -58,6 +58,7 @@ class Service(BaseModel):
replicas: Optional[int] = None
resources: Optional[ServiceResources] = None
extraPodSpec: Optional[PodSpec] = None
subComponentType: Optional[str] = None
model_config = {"extra": "allow"}
......@@ -187,14 +188,86 @@ def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int)
worker_service.multinode.nodeCount = node_count
# TODO: make it work for all frameworks
def get_worker_service_from_config(config: dict):
"""Helper function to get the SGLang decode worker service from config."""
def get_service_name_by_type(
    config: dict, backend: str, sub_component_type: SubComponentType
) -> str:
    """Resolve the name of the worker service that plays a given role.

    The services under ``config["spec"]["services"]`` are scanned in dict
    order and the name of the first one whose ``subComponentType`` equals
    ``sub_component_type.value`` is returned. When no service matches, or the
    config does not have the expected ``spec.services`` mapping, the
    backend's default Kubernetes worker name is returned instead. The default
    is returned whether or not a service with that name exists in the config.

    Args:
        config: Configuration dictionary (with spec.services structure).
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Only
            consulted when the fallback name is needed.
        sub_component_type: The type of sub-component to look for
            (PREFILL or DECODE).

    Returns:
        The service name.

    Raises:
        KeyError: If the fallback is needed and ``backend`` is not a key of
            ``WORKER_COMPONENT_NAMES``.
    """
    # Walk down to spec.services defensively: a missing key, a null ``spec``
    # or a non-mapping ``services`` all mean "use the default name".
    services = None
    if isinstance(config, dict):
        spec = config.get("spec")
        if isinstance(spec, dict):
            services = spec.get("services")

    if isinstance(services, dict):
        wanted = sub_component_type.value
        for service_name, service_config in services.items():
            if (
                isinstance(service_config, dict)
                and service_config.get("subComponentType") == wanted
            ):
                return service_name

    # Fall back to the backend's default component name. The lookup is kept
    # here (not hoisted) so an explicit subComponentType match never depends
    # on ``backend`` being registered.
    component_names = WORKER_COMPONENT_NAMES[backend]
    if sub_component_type == SubComponentType.DECODE:
        return component_names.decode_worker_k8s_name
    return component_names.prefill_worker_k8s_name
def get_worker_service_from_config(
    config: dict,
    backend: str = "sglang",
    sub_component_type: SubComponentType = SubComponentType.DECODE,
):
    """Return the worker service for a backend and role from a config dict.

    The service name is resolved with ``get_service_name_by_type`` (match on
    ``subComponentType`` first, backend default name otherwise), then looked
    up in a ``Config`` validated from ``config``.

    Args:
        config: Configuration dictionary.
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Defaults to
            "sglang".
        sub_component_type: The type of sub-component to look for (PREFILL or
            DECODE). Defaults to DECODE.

    Returns:
        The worker service from the configuration.

    Raises:
        ValueError: If ``backend`` is not a key of ``WORKER_COMPONENT_NAMES``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    # Resolve the name with the type-aware logic so the requested backend and
    # role are honoured (previously a hard-coded sglang decode lookup returned
    # first and made this path unreachable).
    service_name = get_service_name_by_type(config, backend, sub_component_type)

    # NOTE(review): the service is taken from a Config validated inside this
    # function, so it is presumably independent of any Config object the
    # caller validated separately; confirm callers that mutate the result
    # and then dump their own Config.
    cfg = Config.model_validate(config)
    return cfg.spec.services[service_name]
# TODO: make it work for all frameworks
def setup_worker_service_resources(
worker_service, gpu_count: int, num_gpus_per_node: Optional[int] = None
):
......@@ -224,12 +297,24 @@ def setup_worker_service_resources(
worker_service.resources.limits["gpu"] = str(gpu_value)
# TODO: make it work for all frameworks
def validate_and_get_worker_args(worker_service):
"""Helper function to validate worker service and get its arguments."""
def validate_and_get_worker_args(worker_service, backend):
"""Helper function to validate worker service and get its arguments.
Args:
worker_service: Worker service object to validate
backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Required; must be a key of WORKER_COMPONENT_NAMES.
Returns:
List of arguments from the worker service
"""
if backend not in WORKER_COMPONENT_NAMES:
raise ValueError(
f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
)
if not worker_service.extraPodSpec or not worker_service.extraPodSpec.mainContainer:
raise ValueError(
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
f"Missing extraPodSpec or mainContainer in {backend} decode worker service '{WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
......@@ -310,28 +395,29 @@ class VllmV1ConfigModifier:
del cfg.spec.services["Planner"]
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE
)
# convert prefill worker into decode worker
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
] = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
cfg.spec.services[decode_service_name] = cfg.spec.services[
prefill_service_name
]
del cfg.spec.services[prefill_service_name]
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Set subComponentType for aggregated mode (using decode worker for prefill-only)
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="vllm",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args)
# remove --is-prefill-worker flag
......@@ -346,23 +432,26 @@ class VllmV1ConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE
)
# delete prefill worker
del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
del cfg.spec.services[prefill_service_name]
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Set subComponentType for aggregated decode-only mode
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="vllm",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args)
# enable prefix caching
......@@ -374,9 +463,11 @@ class VllmV1ConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "vllm", SubComponentType.DECODE
)
decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1
return cfg.model_dump()
......@@ -384,34 +475,13 @@ class VllmV1ConfigModifier:
@classmethod
def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(config, backend="vllm")
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Set up resources
setup_worker_service_resources(worker_service, tp_size)
# Get and validate args
args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args)
try:
......@@ -438,18 +508,14 @@ class VllmV1ConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
try:
worker_service = get_worker_service_from_config(config, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -535,28 +601,29 @@ class SGLangConfigModifier:
del cfg.spec.services["Planner"]
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE
)
# convert prefill worker into decode worker
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
] = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
cfg.spec.services[decode_service_name] = cfg.spec.services[
prefill_service_name
]
del cfg.spec.services[prefill_service_name]
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Set subComponentType for aggregated mode (using decode worker for prefill-only)
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="sglang")
args = break_arguments(args)
# remove disagg flags
......@@ -571,23 +638,26 @@ class SGLangConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE
)
# delete prefill worker
del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
del cfg.spec.services[prefill_service_name]
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Set subComponentType for aggregated decode-only mode
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="sglang")
args = break_arguments(args)
# remove disagg flags
......@@ -612,23 +682,25 @@ class SGLangConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = config["spec"]["services"][
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]
decode_worker_config["replicas"] = 1
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "sglang", SubComponentType.DECODE
)
decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1
return cfg.model_dump()
@classmethod
def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(config)
worker_service = get_worker_service_from_config(config, backend="sglang")
# Set up resources
setup_worker_service_resources(worker_service, tp_size)
# Get and validate args
args = validate_and_get_worker_args(worker_service)
args = validate_and_get_worker_args(worker_service, backend="sglang")
# Set --tp argument
args = set_argument_value(args, "--tp", str(tp_size))
......@@ -639,13 +711,13 @@ class SGLangConfigModifier:
@classmethod
def set_config_tep_size(cls, config: dict, tep_size: int, num_gpus_per_node: int):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(config)
worker_service = get_worker_service_from_config(config, backend="sglang")
# Set up resources with multinode configuration
setup_worker_service_resources(worker_service, tep_size, num_gpus_per_node)
# Get and validate args
args = validate_and_get_worker_args(worker_service)
args = validate_and_get_worker_args(worker_service, backend="sglang")
# 1. Set --tp=tep_size, if not present add it
args = set_argument_value(args, "--tp", str(tep_size))
......@@ -666,13 +738,13 @@ class SGLangConfigModifier:
@classmethod
def set_config_dep_size(cls, config: dict, dep_size: int, num_gpus_per_node: int):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(config)
worker_service = get_worker_service_from_config(config, backend="sglang")
# Set up resources with multinode configuration
setup_worker_service_resources(worker_service, dep_size, num_gpus_per_node)
# Get and validate args
args = validate_and_get_worker_args(worker_service)
args = validate_and_get_worker_args(worker_service, backend="sglang")
# 1. Set --tp=dep_size
args = set_argument_value(args, "--tp", str(dep_size))
......@@ -692,18 +764,14 @@ class SGLangConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
try:
worker_service = get_worker_service_from_config(config, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -786,28 +854,30 @@ class TrtllmConfigModifier:
del cfg.spec.services["Planner"]
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE
)
# Convert to prefill-only aggregated setup
# Merge prefill worker config into a single worker
if "TRTLLMPrefillWorker" in cfg.spec.services:
# Rename prefill worker to generic worker
cfg.spec.services["TRTLLMWorker"] = cfg.spec.services[
"TRTLLMPrefillWorker"
]
del cfg.spec.services["TRTLLMPrefillWorker"]
# Remove decode worker
del cfg.spec.services["TRTLLMDecodeWorker"]
worker_service = cfg.spec.services["TRTLLMWorker"]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in TRTLLM worker service 'TRTLLMWorker'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Rename prefill worker to decode worker name
cfg.spec.services[decode_service_name] = cfg.spec.services[
prefill_service_name
]
del cfg.spec.services[prefill_service_name]
# Set subComponentType for aggregated mode (using decode worker for prefill-only)
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="trtllm",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = break_arguments(args)
# Remove disaggregation args
......@@ -838,29 +908,28 @@ class TrtllmConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# Convert to decode-only aggregated setup
# Use decode worker as the main worker
if "TRTLLMDecodeWorker" in cfg.spec.services:
# Rename decode worker to generic worker
cfg.spec.services["TRTLLMWorker"] = cfg.spec.services[
"TRTLLMDecodeWorker"
]
del cfg.spec.services["TRTLLMDecodeWorker"]
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE
)
# Convert to decode-only aggregated setup
# Remove prefill worker if exists
if "TRTLLMPrefillWorker" in cfg.spec.services:
del cfg.spec.services["TRTLLMPrefillWorker"]
worker_service = cfg.spec.services["TRTLLMWorker"]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in TRTLLM worker service 'TRTLLMWorker'"
)
args = worker_service.extraPodSpec.mainContainer.args
del cfg.spec.services[prefill_service_name]
# Set subComponentType for aggregated decode-only mode
cfg.spec.services[decode_service_name].subComponentType = "decode"
# Decode worker already has the correct name
worker_service = get_worker_service_from_config(
cfg.model_dump(),
backend="trtllm",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = break_arguments(args)
# Remove disaggregation args
......@@ -887,7 +956,11 @@ class TrtllmConfigModifier:
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
# Set num workers to 1
worker_config = cfg.spec.services["TRTLLMWorker"]
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "trtllm", SubComponentType.DECODE
)
worker_config = cfg.spec.services[final_decode_service_name]
worker_config.replicas = 1
return cfg.model_dump()
......@@ -896,30 +969,15 @@ class TrtllmConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config)
worker_service = cfg.spec.services["TRTLLMWorker"]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Get the worker service using helper function
# This assumes convert_config has been called, so the service is named decode_worker_k8s_name
worker_service = get_worker_service_from_config(config, backend="trtllm")
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
# Set up resources
setup_worker_service_resources(worker_service, tp_size)
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in TRTLLM worker service 'TRTLLMWorker'"
)
args = worker_service.extraPodSpec.mainContainer.args
# Validate and get args
args = validate_and_get_worker_args(worker_service, backend="trtllm")
# Break arguments to handle both joined strings and lists
args = break_arguments(args)
......@@ -951,33 +1009,14 @@ class TrtllmConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = "TRTLLMWorker"
worker_service = cfg.spec.services.get(worker_name)
# Also check for disagg worker names
if not worker_service:
worker_name = "TRTLLMPrefillWorker"
worker_service = cfg.spec.services.get(worker_name)
if not worker_service:
worker_name = "TRTLLMDecodeWorker"
worker_service = cfg.spec.services.get(worker_name)
if not worker_service:
logger.warning(
f"Worker service not found, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
try:
worker_service = get_worker_service_from_config(config, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......
......@@ -7,8 +7,6 @@ metadata:
name: sglang-disagg-planner
spec:
envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:9090"]}]},{"job_name":"frontend","static_configs":[{"targets":["sglang-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE
value: "dynamo"
services:
......@@ -61,45 +59,11 @@ spec:
--backend=sglang
--adjustment-interval=60
--profile-results-dir=/data/profiling_results
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: dynamo
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: my-registry/sglang-runtime:my-tag
workingDir: /workspace/components/backends/sglang
command:
- /bin/sh
- -c
args:
- "python3 -m dynamo.planner.prometheus"
decode:
dynamoNamespace: dynamo
envFromSecret: hf-token-secret
componentType: worker
subComponentType: decode
replicas: 2
resources:
limits:
......@@ -131,6 +95,7 @@ spec:
dynamoNamespace: dynamo
envFromSecret: hf-token-secret
componentType: worker
subComponentType: prefill
replicas: 2
resources:
limits:
......
......@@ -7,8 +7,6 @@ metadata:
name: trtllm-disagg-planner
spec:
envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:8000"]}]},{"job_name":"frontend","static_configs":[{"targets":["trtllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE
value: "trtllm-disagg-planner"
services:
......@@ -41,9 +39,6 @@ spec:
envFromSecret: hf-token-secret
componentType: planner
replicas: 1
envs:
- name: PROMETHEUS_PORT
value: "8000"
livenessProbe:
exec:
command:
......@@ -84,47 +79,11 @@ spec:
- --adjustment-interval=60
- --profile-results-dir=/data/profiling_results
- --prometheus-port=9085
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: trtllm-disagg-planner
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
- name: PROMETHEUS_PORT
value: "8000"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: my-registry/trtllm-runtime:my-tag
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.planner.prometheus
TRTLLMDecodeWorker:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
subComponentType: decode
replicas: 1
livenessProbe:
httpGet:
......@@ -173,6 +132,7 @@ spec:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
subComponentType: prefill
replicas: 1
resources:
limits:
......
......@@ -7,12 +7,8 @@ metadata:
name: vllm-disagg-planner
spec:
envs:
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:9090"]}]},{"job_name":"frontend","static_configs":[{"targets":["vllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE
value: "vllm-disagg-planner"
- name: PROMETHEUS_PORT
value: "8000"
services:
Frontend:
dynamoNamespace: vllm-disagg-planner
......@@ -63,45 +59,11 @@ spec:
--backend=vllm
--adjustment-interval=60
--profile-results-dir=/data/profiling_results
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: vllm-disagg-planner
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/components/backends/vllm
command:
- /bin/sh
- -c
args:
- "python3 -m dynamo.planner.prometheus"
VllmDecodeWorker:
dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
subComponentType: decode
replicas: 2
resources:
limits:
......@@ -127,6 +89,7 @@ spec:
dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
subComponentType: prefill
replicas: 2
resources:
limits:
......
......@@ -8,11 +8,17 @@ __all__ = [
"LoadPlannerDefaults",
"SLAPlannerDefaults",
"ServiceConfig",
"TargetReplica",
"SubComponentType",
]
# Import the classes
from dynamo.planner.config import ServiceConfig
from dynamo.planner.defaults import LoadPlannerDefaults, SLAPlannerDefaults
from dynamo.planner.kubernetes_connector import KubernetesConnector
from dynamo.planner.defaults import (
LoadPlannerDefaults,
SLAPlannerDefaults,
SubComponentType,
)
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.virtual_connector import VirtualConnector
......
......@@ -15,8 +15,16 @@
import logging
import os
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from dynamo.planner.kube import get_current_k8s_namespace
from dynamo.planner.utils.exceptions import (
DuplicateSubComponentError,
SubComponentNotFoundError,
)
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
......@@ -56,6 +64,10 @@ class LoadPlannerDefaults(BasePlannerDefaults):
def _get_default_prometheus_endpoint(port: str, namespace: str):
"""Compute default prometheus endpoint using environment variables and Kubernetes service discovery"""
prometheus_endpoint = os.environ.get("PROMETHEUS_ENDPOINT", "").strip()
if prometheus_endpoint:
logger.debug("Using PROMETHEUS_ENDPOINT override: %s", prometheus_endpoint)
return prometheus_endpoint
k8s_namespace = get_current_k8s_namespace()
if k8s_namespace and k8s_namespace != "default":
......@@ -124,3 +136,67 @@ WORKER_COMPONENT_NAMES = {
"sglang": SGLangComponentName,
"trtllm": TrtllmComponentName,
}
class SubComponentType(str, Enum):
    """Role of a worker service, as written in a service's ``subComponentType``.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    # Worker that serves the prefill role.
    PREFILL = "prefill"
    # Worker that serves the decode role.
    DECODE = "decode"
class Service(BaseModel):
    """Pairs a service's name with its raw definition dictionary."""

    # Key under which the service is registered.
    name: str
    # Raw service definition; only its "replicas" entry is read here.
    service: dict

    def number_replicas(self) -> int:
        """Return the value stored under "replicas", or 0 when the key is absent."""
        replica_count = self.service.get("replicas", 0)
        return replica_count
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
def get_service_from_sub_component_type_or_name(
    deployment: dict,
    sub_component_type: SubComponentType,
    component_name: Optional[str] = None,
) -> Service:
    """Find the single service of a graph deployment with a given sub-component type.

    Services under ``deployment["spec"]["services"]`` are matched on their
    ``subComponentType``. When none matches, ``component_name`` is used as a
    backwards-compatible fallback, but only if that service exists and has no
    (or an empty) ``subComponentType`` of its own.

    Args:
        deployment: Graph deployment dictionary; a missing ``spec`` or
            ``services`` key is treated as "no services".
        sub_component_type: The sub-component type to look for.
        component_name: Optional legacy service name to fall back to.

    Returns:
        Service object wrapping the matched service's name and definition.

    Raises:
        SubComponentNotFoundError: If no service with the specified
            subComponentType is found and the fallback is not usable.
        DuplicateSubComponentError: If multiple services with the same
            subComponentType are found.
    """
    services = deployment.get("spec", {}).get("services", {})
    wanted = sub_component_type.value

    matching_services = [
        (curr_name, curr_service)
        for curr_name, curr_service in services.items()
        if curr_service.get("subComponentType", "") == wanted
    ]

    # More than one service claiming the same type is ambiguous.
    if len(matching_services) > 1:
        service_names = [name for name, _ in matching_services]
        raise DuplicateSubComponentError(wanted, service_names)

    if matching_services:
        name, service = matching_services[0]
        return Service(name=name, service=service)

    # No service found by subComponentType. The fallback component_name is
    # only acceptable when it is provided, exists, and does not declare a
    # (different) non-empty subComponentType itself.
    if (
        not component_name
        or component_name not in services
        or services[component_name].get("subComponentType", "") != ""
    ):
        raise SubComponentNotFoundError(wanted)

    return Service(name=component_name, service=services[component_name])
......@@ -14,12 +14,18 @@
# limitations under the License.
import asyncio
import os
import logging
from typing import Optional
from kubernetes import client, config
from kubernetes.config.config_exception import ConfigException
from dynamo.planner.utils.exceptions import DynamoGraphDeploymentNotFoundError
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
def get_current_k8s_namespace() -> str:
"""Get the current namespace if running inside a k8s cluster"""
......@@ -42,9 +48,7 @@ class KubernetesAPI:
self.custom_api = client.CustomObjectsApi()
self.current_namespace = k8s_namespace or get_current_k8s_namespace()
def _get_graph_deployment_from_name(
self, graph_deployment_name: str
) -> Optional[dict]:
def _get_graph_deployment_from_name(self, graph_deployment_name: str) -> dict:
"""Get the graph deployment from the dynamo graph deployment name"""
return self.custom_api.get_namespaced_custom_object(
group="nvidia.com",
......@@ -54,38 +58,27 @@ class KubernetesAPI:
name=graph_deployment_name,
)
async def get_parent_graph_deployment(self) -> Optional[dict]:
def get_graph_deployment(self, graph_deployment_name: str) -> dict:
"""
Get the parent DynamoGraphDeployment using environment variable.
Uses DYN_PARENT_DGD_K8S_NAME environment variable and assumes the DGD
is in the same namespace as this component (self.current_namespace).
Get the parent DynamoGraphDeployment
Returns:
The DynamoGraphDeployment object or None if env var is not set
"""
dgd_name = os.getenv("DYN_PARENT_DGD_K8S_NAME")
if not dgd_name:
return None
The DynamoGraphDeployment object
Raises:
DynamoGraphDeploymentNotFoundError: If the parent graph deployment is not found
"""
try:
return self._get_graph_deployment_from_name(dgd_name)
return self._get_graph_deployment_from_name(graph_deployment_name)
except client.ApiException as e:
if e.status == 404:
return None
raise DynamoGraphDeploymentNotFoundError(
deployment_name=graph_deployment_name,
namespace=self.current_namespace,
)
raise
async def get_graph_deployment(self) -> Optional[dict]:
"""
Get the parent DynamoGraphDeployment using environment variable.
Returns:
The DynamoGraphDeployment object or None if env var is not set
"""
return await self.get_parent_graph_deployment()
async def update_graph_replicas(
def update_graph_replicas(
self, graph_deployment_name: str, component_name: str, replicas: int
) -> None:
"""Update the replicas count for a component in a DynamoGraphDeployment"""
......@@ -99,15 +92,10 @@ class KubernetesAPI:
body=patch,
)
async def is_deployment_ready(self, graph_deployment_name: str) -> bool:
def is_deployment_ready(self, deployment: dict) -> bool:
"""Check if a graph deployment is ready"""
graph_deployment = self._get_graph_deployment_from_name(graph_deployment_name)
if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found")
conditions = graph_deployment.get("status", {}).get("conditions", [])
conditions = deployment.get("status", {}).get("conditions", [])
ready_condition = next(
(c for c in conditions if c.get("type") == "Ready"), None
)
......@@ -125,12 +113,7 @@ class KubernetesAPI:
for attempt in range(max_attempts):
await asyncio.sleep(delay_seconds)
graph_deployment = self._get_graph_deployment_from_name(
graph_deployment_name
)
if not graph_deployment:
raise ValueError(f"Graph deployment {graph_deployment_name} not found")
graph_deployment = self.get_graph_deployment(graph_deployment_name)
conditions = graph_deployment.get("status", {}).get("conditions", [])
ready_condition = next(
......@@ -140,7 +123,7 @@ class KubernetesAPI:
if ready_condition and ready_condition.get("status") == "True":
return # Deployment is ready
print(
logger.info(
f"[Attempt {attempt + 1}/{max_attempts}] "
f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, "
f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})"
......
......@@ -14,104 +14,291 @@
# limitations under the License.
import logging
import os
import shlex
from typing import Optional
from pydantic import BaseModel
from dynamo.planner.defaults import (
SubComponentType,
get_service_from_sub_component_type_or_name,
)
from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.utils.exceptions import (
DeploymentModelNameMismatchError,
DeploymentValidationError,
EmptyTargetReplicasError,
ModelNameNotFoundError,
PlannerError,
UserProvidedModelNameMismatchError,
)
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class Service(BaseModel):
    """A named service entry of a graph deployment.

    ``name`` is the service key and ``service`` is the raw service dict.
    """

    name: str
    service: dict

    def number_replicas(self) -> int:
        """Return the declared replica count, or 0 when absent or null."""
        replicas = self.service.get("replicas")
        # dict.get only applies its default for a missing key; an explicit
        # null must not leak out as None from a method annotated ``-> int``.
        return 0 if replicas is None else replicas

    def get_model_name(self) -> Optional[str]:
        """Extract the model name from the main container's arguments.

        The arguments under ``extraPodSpec.mainContainer.args`` are tokenized
        with ``break_arguments``. ``--served-model-name`` takes precedence
        over ``--model``. Both ``--flag value`` and ``--flag=value`` are
        recognized.

        Returns:
            The model name, or ``None`` when neither flag carries a value.
        """
        raw_args = (
            self.service.get("extraPodSpec", {})
            .get("mainContainer", {})
            .get("args", [])
        )
        args = break_arguments(raw_args)

        for flag in ("--served-model-name", "--model"):
            prefix = f"{flag}="
            for position, arg in enumerate(args):
                if arg == flag:
                    # A trailing flag without a value is ignored
                    if position + 1 < len(args):
                        return args[position + 1]
                elif arg.startswith(prefix):
                    return arg[len(prefix) :]
        return None
def break_arguments(args: list[str] | None) -> list[str]:
ans: list[str] = []
if args is None:
return ans
if isinstance(args, str):
# Use shlex.split to properly handle quoted arguments and JSON values
ans = shlex.split(args)
else:
for arg in args:
if arg is not None:
# Use shlex.split to properly handle quoted arguments
ans.extend(shlex.split(arg))
return ans
# Scaling request for one sub-component: identifies the service to scale and
# the replica count it should end up with.
class TargetReplica(BaseModel):
# Role used to look the service up by its subComponentType.
sub_component_type: SubComponentType
# Optional service name passed along as the lookup fallback.
component_name: Optional[str] = None
# Replica count the matched service should be set to.
desired_replicas: int
class KubernetesConnector(PlannerConnector):
def __init__(self, dynamo_namespace: str, k8s_namespace: Optional[str] = None):
def __init__(
self,
dynamo_namespace: str,
model_name: Optional[str] = None,
k8s_namespace: Optional[str] = None,
):
self.kube_api = KubernetesAPI(k8s_namespace)
self.dynamo_namespace = dynamo_namespace
async def add_component(self, component_name: str, blocking: bool = True):
self.user_provided_model_name: Optional[str] = None
if model_name:
self.user_provided_model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
graph_deployment_name = os.getenv("DYN_PARENT_DGD_K8S_NAME")
if not graph_deployment_name:
raise DeploymentValidationError(
["DYN_PARENT_DGD_K8S_NAME environment variable is not set"]
)
self.graph_deployment_name = graph_deployment_name
async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Add a component by increasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment()
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found")
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
# get current replicas or 1 if not found
current_replicas = self._get_current_replicas(deployment, component_name)
await self.kube_api.update_graph_replicas(
self._get_graph_deployment_name(deployment),
component_name,
current_replicas + 1,
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
self.kube_api.update_graph_replicas(
self.graph_deployment_name,
service.name,
service.number_replicas() + 1,
)
if blocking:
await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment)
self.graph_deployment_name,
)
async def remove_component(self, component_name: str, blocking: bool = True):
async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Remove a component by decreasing its replica count by 1"""
deployment = await self.kube_api.get_graph_deployment()
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found")
# get current replicas or 1 if not found
current_replicas = self._get_current_replicas(deployment, component_name)
if current_replicas > 0:
await self.kube_api.update_graph_replicas(
self._get_graph_deployment_name(deployment),
component_name,
current_replicas - 1,
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
if service.number_replicas() > 0:
self.kube_api.update_graph_replicas(
self.graph_deployment_name,
service.name,
service.number_replicas() - 1,
)
if blocking:
await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment)
self.graph_deployment_name,
)
async def validate_deployment(
    self,
    prefill_component_name: Optional[str] = None,
    decode_component_name: Optional[str] = None,
):
    """
    Check that the deployment exposes a prefill service, a decode service and a model name.

    Each service is looked up by subComponentType, with the given component
    name as a fallback for backwards compatibility. (TODO: deprecate)
    Every failed check is recorded, so a single error reports all problems.

    Raises:
        DynamoGraphDeploymentNotFoundError: If the deployment is not found
        DeploymentValidationError: If any of the checks above failed
    """
    deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)

    required_services = (
        (SubComponentType.PREFILL, prefill_component_name),
        (SubComponentType.DECODE, decode_component_name),
    )

    problems = []
    for sub_type, fallback_name in required_services:
        try:
            get_service_from_sub_component_type_or_name(
                deployment,
                sub_type,
                component_name=fallback_name,
            )
        except PlannerError as e:
            problems.append(str(e))

    try:
        self.get_model_name(deployment)
    except PlannerError as e:
        problems.append(str(e))

    # Report everything that went wrong in one combined error
    if problems:
        raise DeploymentValidationError(problems)
def get_model_name(self, deployment: Optional[dict] = None) -> str:
    """Get the model name from the deployment.

    The name is read from the prefill and decode services. If only one of
    them declares a name, that one is used; if both do, they must agree.
    When the lookup fails with a PlannerError and a user-provided model name
    exists, that name is used instead.

    Args:
        deployment: Graph deployment dict; fetched through ``self.kube_api``
            when omitted.

    Returns:
        The model name.

    Raises:
        PlannerError: If the lookup fails and no user-provided name exists.
        UserProvidedModelNameMismatchError: If the deployment's model name
            differs from the user-provided one (ignoring case).
        ModelNameNotFoundError: If the resulting model name is empty.
    """
    try:
        if deployment is None:
            deployment = self.kube_api.get_graph_deployment(
                self.graph_deployment_name
            )

        # TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
        # and model name logic, should consolidate
        # The lookup helper is a module-level function, not a method of
        # this class, so it must not be called through ``self``.
        prefill_service = get_service_from_sub_component_type_or_name(
            deployment,
            SubComponentType.PREFILL,
        )
        decode_service = get_service_from_sub_component_type_or_name(
            deployment,
            SubComponentType.DECODE,
        )

        prefill_model_name = prefill_service.get_model_name()
        decode_model_name = decode_service.get_model_name()

        if prefill_model_name is None and decode_model_name is None:
            raise ModelNameNotFoundError()

        # Check model name between prefill and decode
        if prefill_model_name is None:
            model_name = decode_model_name
        elif decode_model_name is None:
            model_name = prefill_model_name
        elif prefill_model_name != decode_model_name:
            raise DeploymentModelNameMismatchError(
                prefill_model_name, decode_model_name
            )
        else:
            model_name = prefill_model_name

    except PlannerError as e:
        if not self.user_provided_model_name:
            raise
        logger.warning(
            f"Failed to get model name from deployment with error: {e}, using provided model name: {self.user_provided_model_name}"
        )
        model_name = self.user_provided_model_name

    # If user provided a model name and it doesn't match the model name from the deployment, raise an error.
    # The user-provided name is stored lower-cased, so the deployment name is
    # lower-cased for the comparison to avoid a spurious mismatch on case.
    if self.user_provided_model_name:
        if model_name.lower() != self.user_provided_model_name:
            raise UserProvidedModelNameMismatchError(
                model_name, self.user_provided_model_name
            )

    if not model_name:
        raise ModelNameNotFoundError()

    return model_name
async def wait_for_deployment_ready(self):
    """Wait until this connector's graph deployment is ready.

    Delegates to ``self.kube_api.wait_for_graph_deployment_ready`` with the
    connector's own graph deployment name.
    """
    wait_until_ready = self.kube_api.wait_for_graph_deployment_ready
    await wait_until_ready(self.graph_deployment_name)
async def set_component_replicas(
self, target_replicas: dict[str, int], blocking: bool = True
self, target_replicas: list[TargetReplica], blocking: bool = True
):
"""Set the replicas for multiple components at once"""
if not target_replicas:
raise ValueError("target_replicas cannot be empty")
raise EmptyTargetReplicasError()
deployment = await self.kube_api.get_graph_deployment()
if deployment is None:
raise ValueError("Parent DynamoGraphDeployment not found")
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
if not await self.kube_api.is_deployment_ready(
self._get_graph_deployment_name(deployment)
):
if not self.kube_api.is_deployment_ready(deployment):
logger.warning(
f"Deployment {self._get_graph_deployment_name(deployment)} is not ready, ignoring this scaling"
f"Deployment {self.graph_deployment_name} is not ready, ignoring this scaling"
)
return
for component_name, replicas in target_replicas.items():
await self.kube_api.update_graph_replicas(
self._get_graph_deployment_name(deployment),
component_name,
replicas,
for target_replica in target_replicas:
service = get_service_from_sub_component_type_or_name(
deployment,
target_replica.sub_component_type,
component_name=target_replica.component_name,
)
current_replicas = service.number_replicas()
if current_replicas != target_replica.desired_replicas:
logger.info(
f"Updating {target_replica.sub_component_type.value} component {service.name} to desired replica count {target_replica.desired_replicas}"
)
self.kube_api.update_graph_replicas(
self.graph_deployment_name,
service.name,
target_replica.desired_replicas,
)
else:
logger.info(
f"{target_replica.sub_component_type.value} component {service.name} already at desired replica count {target_replica.desired_replicas}, skipping"
)
if blocking:
await self.kube_api.wait_for_graph_deployment_ready(
self._get_graph_deployment_name(deployment)
self.graph_deployment_name,
)
def _get_current_replicas(self, deployment: dict, component_name: str) -> int:
"""Get the current replicas for a component in a graph deployment"""
return (
deployment.get("spec", {})
.get("services", {})
.get(component_name, {})
.get("replicas", 1)
)
def _get_graph_deployment_name(self, deployment: dict) -> str:
"""Get the name of the graph deployment"""
return deployment["metadata"]["name"]
if __name__ == "__main__":
import argparse
......@@ -121,13 +308,21 @@ if __name__ == "__main__":
parser.add_argument("--dynamo_namespace", type=str, default="dynamo")
parser.add_argument("--k8s_namespace", type=str, default="default")
parser.add_argument("--action", type=str, choices=["add", "remove"])
parser.add_argument("--component", type=str, default="planner")
parser.add_argument(
"--component",
type=str,
choices=[t.value for t in SubComponentType],
default=SubComponentType.PREFILL.value,
help="Target sub-component to scale",
)
parser.add_argument("--blocking", action="store_true")
args = parser.parse_args()
# Pass k8s_namespace by keyword: when the constructor takes model_name as its
# second positional parameter, a positional namespace would be bound to
# model_name instead.
connector = KubernetesConnector(
    args.dynamo_namespace, k8s_namespace=args.k8s_namespace
)
if args.action == "add":
task = connector.add_component(args.component, args.blocking)
task = connector.add_component(SubComponentType(args.component), args.blocking)
elif args.action == "remove":
task = connector.remove_component(args.component, args.blocking)
task = connector.remove_component(
SubComponentType(args.component), args.blocking
)
asyncio.run(task)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import subprocess
import tempfile
import yaml
from dynamo.planner.config import ServiceConfig
from dynamo.planner.defaults import SLAPlannerDefaults
from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__)
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Load the parsed "Prometheus" service config and run the server with it.

    The config is logged before the server is started. ``runtime`` is not
    used by the body itself.
    """
    config = ServiceConfig.get_parsed_config("Prometheus")
    logger.info(f"Prometheus config: {config}")

    await start_prometheus_server(config)
async def start_prometheus_server(config):
    """Run a ``prometheus`` subprocess configured from ``config``.

    ``config`` is dumped as YAML into a temporary ``.yml`` file that is passed
    to the process via ``--config.file``. The coroutine then polls the process
    once per second and returns when it exits. On cancellation the process is
    terminated and waited for before the cancellation is re-raised. The
    temporary config file is removed on every exit path.

    Args:
        config: Object serializable by ``yaml.dump``.

    Raises:
        asyncio.CancelledError: Re-raised after the process has been stopped.
    """
    import os

    logger.info("Starting prometheus server...")

    # delete=False keeps the file on disk after close() so the subprocess can
    # read it; it is removed explicitly in the finally block below.
    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
    config_path = temp_file.name

    try:
        with temp_file:
            yaml.dump(config, temp_file)

        prometheus_port = SLAPlannerDefaults.port
        cmd = [
            "prometheus",
            f"--config.file={config_path}",
            f"--web.listen-address=0.0.0.0:{prometheus_port}",
        ]

        logger.info(f"Prometheus cmd: {cmd}")

        process = subprocess.Popen(
            cmd,
            stdout=None,
            stderr=None,
        )

        # Keep the worker running
        try:
            while True:
                await asyncio.sleep(1)
                if process.poll() is not None:
                    logger.error("Prometheus process died")
                    break
        except asyncio.CancelledError:
            logger.info("Shutting down Prometheus...")
            process.terminate()
            process.wait()
            raise
    finally:
        # Reached only once the process has exited, was terminated, or never
        # started, so nothing reads the file any more.
        try:
            os.remove(config_path)
        except OSError:
            pass
if __name__ == "__main__":
    # The dynamo_worker decorator handles runtime setup, so worker() is called
    # without arguments. asyncio is already imported at module level, so it is
    # not imported again here.
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom exceptions for the dynamo planner module.
This module defines a hierarchy of custom exceptions that provide more specific
error handling than generic ValueError exceptions. Each exception includes
contextual information to help with debugging and error handling.
"""
from typing import List
class PlannerError(Exception):
    """Root of the planner exception hierarchy.

    Every custom planner exception derives from this class, so catching it
    handles any planner-specific failure in one place.
    """
class DynamoGraphDeploymentNotFoundError(PlannerError):
    """Raised when the parent DynamoGraphDeployment cannot be found.

    Carries the deployment name and the namespace that were searched as
    ``deployment_name`` and ``namespace``.
    """

    def __init__(self, deployment_name: str, namespace: str):
        self.deployment_name = deployment_name
        self.namespace = namespace
        super().__init__(
            f"Parent DynamoGraphDeployment not found (name: '{deployment_name}' in namespace '{namespace}')"
        )
class ComponentError(PlannerError):
    """Common parent for subComponentType configuration errors.

    Exceptions about missing or ambiguous subComponentType declarations in a
    DynamoGraphDeployment derive from this class.
    """
class ModelNameNotFoundError(PlannerError):
    """Raised when no model name can be determined for the deployment."""

    def __init__(self):
        super().__init__("Model name not found in DynamoGraphDeployment")
class DeploymentModelNameMismatchError(PlannerError):
    """Raised when the prefill and decode model names of a deployment differ.

    Both names are kept on the exception, and the rendered text is also
    available as ``message``.
    """

    def __init__(self, prefill_model_name: str, decode_model_name: str):
        self.prefill_model_name = prefill_model_name
        self.decode_model_name = decode_model_name
        self.message = f"Model name mismatch in DynamoGraphDeployment: prefill model name {prefill_model_name} != decode model name {decode_model_name}"
        super().__init__(self.message)
class UserProvidedModelNameMismatchError(PlannerError):
    """Raised when the detected model name differs from the user-provided one.

    Both names are kept on the exception, and the rendered text is also
    available as ``message``.
    """

    def __init__(self, model_name: str, user_provided_model_name: str):
        self.model_name = model_name
        self.user_provided_model_name = user_provided_model_name
        self.message = f"Model name {model_name} does not match expected model name {user_provided_model_name}"
        super().__init__(self.message)
class BackendFrameworkNotFoundError(PlannerError):
    """Raised when no backend framework is found on the DynamoGraphDeployment."""

    def __init__(self):
        super().__init__("Backend framework not found on DynamoGraphDeployment")
class BackendFrameworkInvalidError(PlannerError):
    """Raised when a backend framework value is invalid.

    The offending value is kept as ``backend_framework``.
    """

    def __init__(self, backend_framework: str):
        self.backend_framework = backend_framework
        super().__init__(f"Backend framework {backend_framework} is invalid")
class SubComponentNotFoundError(ComponentError):
    """Raised when no service with the required subComponentType exists.

    The requested type (e.g. 'prefill', 'decode') is kept as
    ``sub_component_type``.
    """

    def __init__(self, sub_component_type: str):
        self.sub_component_type = sub_component_type
        super().__init__(
            f"DynamoGraphDeployment must contain a service with subComponentType '{sub_component_type}'"
        )
class DuplicateSubComponentError(ComponentError):
    """Raised when several services declare the same subComponentType.

    A subComponentType is expected to identify exactly one service. The
    conflicting service names are kept as ``service_names`` and appear in
    sorted order in the message.
    """

    def __init__(self, sub_component_type: str, service_names: List[str]):
        self.sub_component_type = sub_component_type
        self.service_names = service_names
        super().__init__(
            f"DynamoGraphDeployment must contain only one service with "
            f"subComponentType '{sub_component_type}', but found multiple: "
            f"{', '.join(sorted(service_names))}"
        )
class DeploymentValidationError(PlannerError):
    """Raised when deployment validation fails.

    Bundles one or more validation problems into a single exception: the
    individual messages are kept as ``errors`` and joined with '; ' in the
    exception text.
    """

    def __init__(self, errors: List[str]):
        self.errors = errors
        super().__init__(f"Service verification failed: {'; '.join(errors)}")
class EmptyTargetReplicasError(PlannerError):
    """Raised when component replicas are set with empty target_replicas."""

    def __init__(self):
        super().__init__("target_replicas cannot be empty")
......@@ -118,4 +118,9 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
default=SLAPlannerDefaults.no_correction,
help="Disable correction factor",
)
parser.add_argument(
"--model-name",
type=str,
help="Model name of deployment (only required for virtual environment)",
)
return parser
......@@ -11,7 +11,12 @@ from typing import Optional
from prometheus_client import Gauge, start_http_server
from dynamo.planner import KubernetesConnector, VirtualConnector
from dynamo.planner import (
KubernetesConnector,
SubComponentType,
TargetReplica,
VirtualConnector,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import (
......@@ -63,22 +68,30 @@ class Planner:
self.args = args
self.dryrun = dryrun
# Rely on getting model name from connector
self.model_name: Optional[str] = None
if not self.dryrun:
self.runtime = runtime
self.namespace = args.namespace
if not args.no_operation:
if args.environment == "kubernetes":
self.connector = KubernetesConnector(self.namespace)
self.connector = KubernetesConnector(
self.namespace, self.model_name
)
elif args.environment == "virtual":
self.connector = VirtualConnector(
runtime, self.namespace, args.backend
runtime,
self.namespace,
args.model_name,
)
else:
raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient(
SLAPlannerDefaults.prometheus_endpoint
SLAPlannerDefaults.prometheus_endpoint,
args.namespace,
)
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor](
......@@ -121,6 +134,13 @@ class Planner:
self.prefill_interpolator = PrefillInterpolator(args.profile_results_dir)
self.decode_interpolator = DecodeInterpolator(args.profile_results_dir)
self.prefill_component_name = WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_k8s_name
self.decode_component_name = WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_k8s_name
if not self.dryrun:
self.prefill_client = None
self.workers_client = None
......@@ -230,27 +250,33 @@ class Planner:
self.num_d_workers_gauge.set(len(self.d_endpoints))
self.last_metrics.ttft = self.prometheus_api_client.get_avg_time_to_first_token(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
self.last_metrics.itl = self.prometheus_api_client.get_avg_inter_token_latency(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
self.last_metrics.num_req = self.prometheus_api_client.get_avg_request_count(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
self.last_metrics.request_duration = (
self.prometheus_api_client.get_avg_request_duration(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
)
self.last_metrics.isl = (
self.prometheus_api_client.get_avg_input_sequence_tokens(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
)
self.last_metrics.osl = (
self.prometheus_api_client.get_avg_output_sequence_tokens(
f"{self.args.adjustment_interval}s"
f"{self.args.adjustment_interval}s",
self.model_name,
)
)
......@@ -429,19 +455,43 @@ class Planner:
return
if not self.args.no_operation:
target_replicas = {
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_k8s_name: next_num_p,
WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_k8s_name: next_num_d,
}
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_component_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_component_name,
desired_replicas=next_num_d,
),
]
await self.connector.set_component_replicas(target_replicas, blocking=False)
async def run(self):
"""Main loop for the planner"""
if not self.args.no_operation:
# Fail fast if the deployment is not valid
logger.info("Validating deployment...")
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
await self.connector.validate_deployment(
prefill_component_name=self.prefill_component_name,
decode_component_name=self.decode_component_name,
)
logger.info("Successfully validated the deployment")
await self.connector.wait_for_deployment_ready()
model_name = self.connector.get_model_name()
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
self.last_adjustment_time = time.time()
while True:
......@@ -453,6 +503,7 @@ class Planner:
):
self.last_adjustment_time = time.time()
logger.info("New adjustment interval started!")
await self.observe_metrics()
await self.make_adjustments()
......
......@@ -14,8 +14,10 @@
# limitations under the License.
import logging
import typing
from prometheus_api_client import PrometheusConnect
from pydantic import BaseModel, ValidationError
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -23,12 +25,33 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__)
# Labels of one frontend metric sample as validated from a Prometheus query
# result. Every field is optional and defaults to None.
class FrontendMetric(BaseModel):
container: typing.Optional[str] = None
# Compared against the client's dynamo namespace when filtering samples.
dynamo_namespace: typing.Optional[str] = None
endpoint: typing.Optional[str] = None
instance: typing.Optional[str] = None
job: typing.Optional[str] = None
# Compared against the requested model name when filtering samples.
model: typing.Optional[str] = None
# NOTE(review): presumably the Kubernetes namespace label - confirm.
namespace: typing.Optional[str] = None
pod: typing.Optional[str] = None
# One entry of a Prometheus query result: the sample's labels plus its
# (timestamp, value) pair.
class FrontendMetricContainer(BaseModel):
# Label set of the sample.
metric: FrontendMetric
# Index 0 is the timestamp, index 1 the sample value; both parsed as float.
value: typing.Tuple[float, float] # [timestamp, value]
class PrometheusAPIClient:
def __init__(self, url: str):
def __init__(self, url: str, dynamo_namespace: str):
self.prom = PrometheusConnect(url=url, disable_ssl=True)
self.dynamo_namespace = dynamo_namespace
def _get_average_metric(
self, metric_name: str, interval: str, operation_name: str
self,
metric_name: str,
interval: str,
operation_name: str,
model_name: str,
) -> float:
"""
Helper method to get average metrics using the pattern:
......@@ -50,57 +73,92 @@ class PrometheusAPIClient:
if not result:
# No data available yet (no requests made) - return 0 silently
return 0
return float(result[0]["value"][1])
metrics_containers = parse_frontend_metric_containers(result)
values = []
for container in metrics_containers:
if (
container.metric.model == model_name
and container.metric.dynamo_namespace == self.dynamo_namespace
):
values.append(container.value[1])
if not values:
return 0
return sum(values) / len(values)
except Exception as e:
logger.error(f"Error getting {operation_name}: {e}")
return 0
def get_avg_inter_token_latency(self, interval: str):
def get_avg_inter_token_latency(self, interval: str, model_name: str):
return self._get_average_metric(
"inter_token_latency_seconds",
interval,
"avg inter token latency",
model_name,
)
def get_avg_time_to_first_token(self, interval: str):
def get_avg_time_to_first_token(self, interval: str, model_name: str):
return self._get_average_metric(
"time_to_first_token_seconds",
interval,
"avg time to first token",
model_name,
)
def get_avg_request_duration(self, interval: str):
def get_avg_request_duration(self, interval: str, model_name: str):
return self._get_average_metric(
"request_duration_seconds",
interval,
"avg request duration",
model_name,
)
def get_avg_request_count(self, interval: str):
def get_avg_request_count(self, interval: str, model_name: str):
# This function follows a different query pattern than the other metrics
try:
raw_res = self.prom.custom_query(
query=f"increase(dynamo_frontend_requests_total[{interval}])"
)
metrics_containers = parse_frontend_metric_containers(raw_res)
total_count = 0.0
for res in raw_res:
# count all success/failed and stream/non-stream requests
total_count += float(res["value"][1])
for container in metrics_containers:
if (
container.metric.model == model_name
and container.metric.dynamo_namespace == self.dynamo_namespace
):
total_count += container.value[1]
return total_count
except Exception as e:
logger.error(f"Error getting avg request count: {e}")
return 0
def get_avg_input_sequence_tokens(self, interval: str):
def get_avg_input_sequence_tokens(self, interval: str, model_name: str):
return self._get_average_metric(
"input_sequence_tokens",
interval,
"avg input sequence tokens",
model_name,
)
def get_avg_output_sequence_tokens(self, interval: str):
def get_avg_output_sequence_tokens(self, interval: str, model_name: str):
return self._get_average_metric(
"output_sequence_tokens",
interval,
"avg output sequence tokens",
model_name,
)
def parse_frontend_metric_containers(
    result: list[dict],
) -> list[FrontendMetricContainer]:
    """Validate raw query result entries into FrontendMetricContainer objects.

    Entries that fail validation are logged as errors and left out, so the
    returned list may be shorter than ``result``.
    """
    parsed: list[FrontendMetricContainer] = []
    for raw_entry in result:
        try:
            container = FrontendMetricContainer.model_validate(raw_entry)
        except ValidationError as e:
            logger.error(f"Error parsing frontend metric container: {e}")
        else:
            parsed.append(container)
    return parsed
......@@ -6,8 +6,9 @@ import os
from typing import Optional
from dynamo._core import VirtualConnectorCoordinator
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -32,7 +33,10 @@ class VirtualConnector(PlannerConnector):
"""
def __init__(
self, runtime: DistributedRuntime, dynamo_namespace: str, backend: str
self,
runtime: DistributedRuntime,
dynamo_namespace: str,
model_name: Optional[str] = None,
):
self.connector = VirtualConnectorCoordinator(
runtime,
......@@ -42,8 +46,12 @@ class VirtualConnector(PlannerConnector):
SCALING_MAX_RETRIES,
)
self.backend = backend
self.worker_component_names = WORKER_COMPONENT_NAMES[backend]
if model_name is None:
raise ValueError("Model name is required for virtual connector")
self.model_name = model_name.lower() # normalize model name to lowercase (MDC)
self.dynamo_namespace = dynamo_namespace
async def _async_init(self):
"""Async initialization that must be called after __init__"""
......@@ -59,47 +67,32 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment environment to report that scaling is complete"""
await self.connector.wait_for_scaling_completion()
def _component_to_worker_type(self, component_name: str) -> Optional[str]:
    """Map component name to worker type (prefill or decode).

    Compares ``component_name`` against the ``prefill_worker_k8s_name`` and
    ``decode_worker_k8s_name`` attributes of ``self.worker_component_names``.

    Returns:
        ``"prefill"`` or ``"decode"`` on a match, otherwise ``None``.
    """
    names = self.worker_component_names
    # Prefill is checked first, so it wins if both names are equal.
    if component_name == names.prefill_worker_k8s_name:
        return "prefill"
    if component_name == names.decode_worker_k8s_name:
        return "decode"
    return None
async def add_component(
    self, sub_component_type: SubComponentType, blocking: bool = True
):
    """Add a component by increasing its replica count by 1.

    Reads the current state from ``self.connector`` and submits a scaling
    decision with the prefill or decode worker count incremented by one,
    depending on ``sub_component_type``. Any other value submits nothing.

    Args:
        sub_component_type: ``SubComponentType.PREFILL`` or
            ``SubComponentType.DECODE``.
        blocking: When true, wait for scaling completion before returning.
            The wait happens even if no scaling decision was submitted.
    """
    state = self.connector.read_state()

    if sub_component_type == SubComponentType.PREFILL:
        await self._update_scaling_decision(
            num_prefill=state.num_prefill_workers + 1
        )
    elif sub_component_type == SubComponentType.DECODE:
        await self._update_scaling_decision(num_decode=state.num_decode_workers + 1)

    if blocking:
        await self._wait_for_scaling_completion()
async def remove_component(self, component_name: str, blocking: bool = True):
async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""Remove a component by decreasing its replica count by 1"""
worker_type = self._component_to_worker_type(component_name)
if worker_type is None:
logger.warning(f"Unknown component name: {component_name}, skipping")
return
state = self.connector.read_state()
if worker_type == "prefill":
if sub_component_type == SubComponentType.PREFILL:
new_count = max(0, state.num_prefill_workers - 1)
await self._update_scaling_decision(num_prefill=new_count)
elif worker_type == "decode":
elif sub_component_type == SubComponentType.DECODE:
new_count = max(0, state.num_decode_workers - 1)
await self._update_scaling_decision(num_decode=new_count)
......@@ -107,25 +100,20 @@ class VirtualConnector(PlannerConnector):
await self._wait_for_scaling_completion()
async def set_component_replicas(
self, target_replicas: dict[str, int], blocking: bool = True
self, target_replicas: list[TargetReplica], blocking: bool = True
):
"""Set the replicas for multiple components at once"""
if not target_replicas:
raise ValueError("target_replicas cannot be empty")
raise EmptyTargetReplicasError()
num_prefill = None
num_decode = None
for component_name, replicas in target_replicas.items():
worker_type = self._component_to_worker_type(component_name)
if worker_type is None:
logger.warning(f"Unknown component name: {component_name}, skipping")
continue
if worker_type == "prefill":
num_prefill = replicas
elif worker_type == "decode":
num_decode = replicas
for target_replica in target_replicas:
if target_replica.sub_component_type == SubComponentType.PREFILL:
num_prefill = target_replica.desired_replicas
elif target_replica.sub_component_type == SubComponentType.DECODE:
num_decode = target_replica.desired_replicas
if num_prefill is None and num_decode is None:
return
......@@ -137,3 +125,19 @@ class VirtualConnector(PlannerConnector):
if blocking:
await self._wait_for_scaling_completion()
async def validate_deployment(
    self,
    prefill_component_name: Optional[str] = None,
    decode_component_name: Optional[str] = None,
):
    """Validate the deployment.

    This implementation performs no checks: both arguments are accepted and
    ignored, and the coroutine always returns ``None``.
    """
async def wait_for_deployment_ready(self):
    """Wait for the deployment to be ready.

    Delegates to ``self._wait_for_scaling_completion()`` and returns once
    that coroutine finishes.
    """
    await self._wait_for_scaling_completion()
async def get_model_name(self) -> str:
    """Get the model name from the deployment.

    Returns the value stored in ``self.model_name``; nothing is queried.
    """
    return self.model_name
......@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict
from unittest.mock import MagicMock, patch
import pytest
from kubernetes import client
from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.utils.exceptions import DynamoGraphDeploymentNotFoundError
@pytest.fixture
......@@ -75,6 +76,21 @@ def test_get_graph_deployment_from_name(k8s_api, mock_custom_api):
)
def test_update_graph_replicas(k8s_api, mock_custom_api):
    """``update_graph_replicas`` issues exactly one patch with the new count."""
    # Arrange
    patch_call = mock_custom_api.patch_namespaced_custom_object
    patch_call.return_value = None

    # Act
    k8s_api.update_graph_replicas("test-deployment", "test-component", 1)

    # Assert: only the targeted service's replica count is in the patch body.
    expected_body = {"spec": {"services": {"test-component": {"replicas": 1}}}}
    patch_call.assert_called_once_with(
        group="nvidia.com",
        version="v1alpha1",
        namespace=k8s_api.current_namespace,
        plural="dynamographdeployments",
        name="test-deployment",
        body=expected_body,
    )
@pytest.mark.asyncio
async def test_is_deployment_ready_true(k8s_api, mock_custom_api):
"""Test is_deployment_ready method when deployment is ready"""
......@@ -87,12 +103,8 @@ async def test_is_deployment_ready_true(k8s_api, mock_custom_api):
}
}
# Mock the method on the instance
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
result = await k8s_api.is_deployment_ready("test-deployment")
assert result is True
result = k8s_api.is_deployment_ready(mock_deployment)
assert result is True
@pytest.mark.asyncio
......@@ -109,24 +121,8 @@ async def test_is_deployment_ready_false(k8s_api, mock_custom_api):
]
}
}
# Mock the method on the instance
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
result = await k8s_api.is_deployment_ready("test-deployment")
assert result is False
@pytest.mark.asyncio
async def test_is_deployment_ready_not_found(k8s_api, mock_custom_api):
"""Test is_deployment_ready method when deployment is not found"""
# Mock the method on the instance
with patch.object(k8s_api, "_get_graph_deployment_from_name", return_value=None):
with pytest.raises(ValueError) as exc_info:
await k8s_api.is_deployment_ready("test-deployment")
assert "not found" in str(exc_info.value)
result = k8s_api.is_deployment_ready(mock_deployment)
assert result is False
@pytest.mark.asyncio
......@@ -142,9 +138,7 @@ async def test_wait_for_graph_deployment_ready_success(k8s_api, mock_custom_api)
}
# Mock the method on the instance
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
# Test with minimal attempts and delay for faster testing
await k8s_api.wait_for_graph_deployment_ready(
"test-deployment", max_attempts=2, delay_seconds=0.1
......@@ -168,9 +162,7 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api)
}
# Mock the method on the instance
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
# Test with minimal attempts and delay for faster testing
with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
......@@ -183,15 +175,21 @@ async def test_wait_for_graph_deployment_ready_timeout(k8s_api, mock_custom_api)
@pytest.mark.asyncio
async def test_wait_for_graph_deployment_not_found(k8s_api, mock_custom_api):
    """Test wait_for_graph_deployment_ready when deployment is not found"""
    # Make the custom-objects lookup raise a 404 ApiException.
    mock_custom_api.get_namespaced_custom_object.side_effect = client.ApiException(
        status=404
    )

    # Test with minimal attempts and delay for faster testing
    with pytest.raises(DynamoGraphDeploymentNotFoundError) as exc_info:
        await k8s_api.wait_for_graph_deployment_ready(
            "test-deployment", max_attempts=2, delay_seconds=0.1
        )

    # Validate the exception fields
    exception = exc_info.value
    assert exception.deployment_name == "test-deployment"
    assert exception.namespace == "default"
@pytest.mark.asyncio
......@@ -200,9 +198,7 @@ async def test_wait_for_graph_deployment_no_conditions(k8s_api, mock_custom_api)
# Mock the _get_graph_deployment_from_name response with no conditions
mock_deployment: Dict[str, Any] = {"status": {}}
with patch.object(
k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
):
with patch.object(k8s_api, "get_graph_deployment", return_value=mock_deployment):
# Test with minimal attempts and delay for faster testing
with pytest.raises(TimeoutError) as exc_info:
await k8s_api.wait_for_graph_deployment_ready(
......@@ -249,37 +245,28 @@ async def test_wait_for_graph_deployment_ready_on_second_attempt(
@pytest.mark.asyncio
async def test_get_graph_deployment(k8s_api, mock_custom_api):
    """Test get_graph_deployment"""
    mock_deployment = {"metadata": {"name": "parent-dgd"}}

    # Stub the name-based lookup so the call never reaches the API mock.
    with patch.object(
        k8s_api, "_get_graph_deployment_from_name", return_value=mock_deployment
    ) as mock_get:
        result = await k8s_api.get_graph_deployment("parent-dgd")

        assert result == mock_deployment
        mock_get.assert_called_once_with("parent-dgd")


@pytest.mark.asyncio
async def test_get_graph_deployment_not_found(k8s_api, mock_custom_api):
    """Test get_graph_deployment when deployment is not found"""
    # Make the custom-objects lookup raise a 404 ApiException.
    k8s_api.custom_api.get_namespaced_custom_object.side_effect = client.ApiException(
        status=404
    )

    with pytest.raises(DynamoGraphDeploymentNotFoundError) as exc_info:
        await k8s_api.get_graph_deployment("parent-dgd")

    # The exception carries the deployment name and namespace.
    exception = exc_info.value
    assert exception.deployment_name == "parent-dgd"
    assert exception.namespace == "default"
......@@ -13,20 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock, Mock
import os
from unittest.mock import AsyncMock, Mock, call, patch
import pytest
from dynamo.planner.kubernetes_connector import KubernetesConnector
from dynamo.planner.defaults import (
SubComponentType,
get_service_from_sub_component_type_or_name,
)
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.utils.exceptions import (
DeploymentModelNameMismatchError,
DeploymentValidationError,
DuplicateSubComponentError,
DynamoGraphDeploymentNotFoundError,
EmptyTargetReplicasError,
ModelNameNotFoundError,
SubComponentNotFoundError,
)
@pytest.fixture
def mock_kube_api():
    """Build a ``Mock`` standing in for the Kubernetes API object.

    ``get_graph_deployment`` and ``is_deployment_ready`` are plain ``Mock``
    attributes, so calling them returns their ``return_value`` directly.
    ``update_graph_replicas`` and ``wait_for_graph_deployment_ready`` are
    ``AsyncMock`` attributes, so calling them returns an awaitable.
    """
    mock_api = Mock()
    mock_api.get_graph_deployment = Mock()
    mock_api.update_graph_replicas = AsyncMock()
    mock_api.wait_for_graph_deployment_ready = AsyncMock()
    mock_api.is_deployment_ready = Mock()
    return mock_api
......@@ -43,24 +57,150 @@ def kubernetes_connector(mock_kube_api_class, monkeypatch):
monkeypatch.setattr(
"dynamo.planner.kubernetes_connector.KubernetesAPI", mock_kube_api_class
)
connector = KubernetesConnector("test-dynamo-namespace", "default")
return connector
with patch.dict(os.environ, {"DYN_PARENT_DGD_K8S_NAME": "test-graph"}):
connector = KubernetesConnector("test-dynamo-namespace")
return connector
def test_kubernetes_connector_no_env_var():
    """Constructing the connector must fail with the unset-env-var error.

    NOTE(review): this test does not clear ``DYN_PARENT_DGD_K8S_NAME`` itself,
    so it presumably relies on the variable being absent from the ambient
    environment; confirm against the CI setup.
    """
    expected_errors = {"DYN_PARENT_DGD_K8S_NAME environment variable is not set"}

    with pytest.raises(DeploymentValidationError) as exc_info:
        KubernetesConnector("test-dynamo-namespace")

    raised = exc_info.value
    assert set(raised.errors) == expected_errors
def test_get_service_name_from_sub_component_type(kubernetes_connector):
    """Services are resolved by ``subComponentType``, with or without a name."""
    services = {
        "test-component-prefill": {
            "replicas": 2,
            "subComponentType": "prefill",
        },
        "test-component-decode": {"replicas": 3, "subComponentType": "decode"},
    }
    deployment = {"metadata": {"name": "test-graph"}, "spec": {"services": services}}

    # Lookup by type alone.
    by_type = get_service_from_sub_component_type_or_name(
        deployment, SubComponentType.PREFILL
    )
    assert by_type.name == "test-component-prefill"
    assert by_type.number_replicas() == 2

    # should still work if the component_name is provided
    by_type_and_name = get_service_from_sub_component_type_or_name(
        deployment, SubComponentType.PREFILL, "test-component-prefill"
    )
    assert by_type_and_name.name == "test-component-prefill"
    assert by_type_and_name.number_replicas() == 2

    # should respect subComponentType first
    type_wins = get_service_from_sub_component_type_or_name(
        deployment, SubComponentType.DECODE, "test-component-prefill"
    )
    assert type_wins.name == "test-component-decode"
    assert type_wins.number_replicas() == 3
def test_get_service_name_from_sub_component_type_not_found(kubernetes_connector):
    """A missing prefill service raises ``SubComponentNotFoundError``."""
    decode_only = {
        "test-component-decode": {"replicas": 3, "subComponentType": "decode"},
    }
    deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": decode_only},
    }
    wanted = SubComponentType.PREFILL

    # Lookup by type alone.
    with pytest.raises(SubComponentNotFoundError):
        get_service_from_sub_component_type_or_name(deployment, wanted)

    # Naming a service that carries a different subComponentType also fails.
    with pytest.raises(SubComponentNotFoundError) as exc_info:
        get_service_from_sub_component_type_or_name(
            deployment, wanted, "test-component-decode"
        )

    raised = exc_info.value
    assert raised.sub_component_type == SubComponentType.PREFILL.value
def test_get_service_name_from_sub_component_type_duplicate(kubernetes_connector):
    """Two services sharing a ``subComponentType`` raise a duplicate error."""
    services = {
        "test-component-prefill": {
            "replicas": 2,
            "subComponentType": "prefill",
        },
        "test-component-prefill-2": {
            "replicas": 3,
            "subComponentType": "prefill",
        },
    }
    deployment = {"metadata": {"name": "test-graph"}, "spec": {"services": services}}

    with pytest.raises(DuplicateSubComponentError) as exc_info:
        # even though "test-component-prefill" is provided, subComponentType duplicates should result in an error
        get_service_from_sub_component_type_or_name(
            deployment, SubComponentType.PREFILL, "test-component-prefill"
        )

    raised = exc_info.value
    assert raised.sub_component_type == SubComponentType.PREFILL.value
    # Order of the reported names is not asserted.
    assert set(raised.service_names) == {
        "test-component-prefill",
        "test-component-prefill-2",
    }
def test_get_service_name_from_sub_component_type_or_name(kubernetes_connector):
    """With no ``subComponentType`` set, the component name resolves the service."""
    untyped_services = {
        "test-component-prefill": {"replicas": 2},
        "test-component-decode": {"replicas": 3},
    }
    deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": untyped_services},
    }

    resolved = get_service_from_sub_component_type_or_name(
        deployment, SubComponentType.PREFILL, "test-component-prefill"
    )

    assert resolved.name == "test-component-prefill"
    assert resolved.number_replicas() == 2
@pytest.mark.asyncio
async def test_add_component_increases_replicas(kubernetes_connector, mock_kube_api):
# Arrange
sub_component_type = SubComponentType.PREFILL
component_name = "test-component"
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {"services": {"test-component": {"replicas": 1}}},
"spec": {
"services": {
component_name: {
"replicas": 1,
"subComponentType": sub_component_type.value,
}
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.update_graph_replicas.return_value = None
mock_kube_api.wait_for_graph_deployment_ready.return_value = None
# Act
await kubernetes_connector.add_component(component_name)
await kubernetes_connector.add_component(sub_component_type)
# Assert
mock_kube_api.get_graph_deployment.assert_called_once()
......@@ -75,19 +215,22 @@ async def test_add_component_with_no_replicas_specified(
kubernetes_connector, mock_kube_api
):
# Arrange
sub_component_type = SubComponentType.PREFILL
component_name = "test-component"
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {"services": {"test-component": {}}},
"spec": {
"services": {component_name: {"subComponentType": sub_component_type.value}}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
# Act
await kubernetes_connector.add_component(component_name)
await kubernetes_connector.add_component(sub_component_type)
# Assert
mock_kube_api.update_graph_replicas.assert_called_once_with(
"test-graph", component_name, 2
"test-graph", component_name, 1
)
mock_kube_api.wait_for_graph_deployment_ready.assert_called_once_with("test-graph")
......@@ -96,25 +239,55 @@ async def test_add_component_with_no_replicas_specified(
async def test_add_component_deployment_not_found(kubernetes_connector, mock_kube_api):
    """``add_component`` propagates a missing-deployment error from the API."""
    # Arrange
    # NOTE(review): sibling tests pass a SubComponentType to add_component;
    # this one still passes a component-name string. The lookup raises before
    # the argument matters here, but confirm that is intended.
    component_name = "test-component"
    mock_kube_api.get_graph_deployment.side_effect = DynamoGraphDeploymentNotFoundError(
        "test-graph", "default"
    )

    # Act & Assert
    with pytest.raises(DynamoGraphDeploymentNotFoundError):
        await kubernetes_connector.add_component(component_name)
@pytest.mark.asyncio
async def test_add_component_component_not_found(kubernetes_connector, mock_kube_api):
    """Adding a prefill worker fails when only a decode service exists."""
    # Arrange
    decode_only = {"test-component": {"subComponentType": "decode"}}
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": decode_only},
    }

    # Act
    with pytest.raises(SubComponentNotFoundError) as exc_info:
        await kubernetes_connector.add_component(SubComponentType.PREFILL)

    # Assert: nothing was patched or waited on.
    mock_kube_api.update_graph_replicas.assert_not_called()
    mock_kube_api.wait_for_graph_deployment_ready.assert_not_called()
    raised = exc_info.value
    assert raised.sub_component_type == "prefill"
@pytest.mark.asyncio
async def test_remove_component_decreases_replicas(kubernetes_connector, mock_kube_api):
# Arrange
component_name = "test-component"
sub_component_type = SubComponentType.PREFILL
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {"services": {"test-component": {"replicas": 2}}},
"spec": {
"services": {
"test-component": {
"replicas": 2,
"subComponentType": sub_component_type.value,
}
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
# Act
await kubernetes_connector.remove_component(component_name)
await kubernetes_connector.remove_component(sub_component_type)
# Assert
mock_kube_api.update_graph_replicas.assert_called_once_with(
......@@ -127,33 +300,82 @@ async def test_remove_component_decreases_replicas(kubernetes_connector, mock_ku
async def test_remove_component_with_zero_replicas(kubernetes_connector, mock_kube_api):
    """Removing from a service already at zero replicas must not patch or wait."""
    # Arrange
    component_name = "test-component"
    sub_component_type = SubComponentType.PREFILL
    mock_deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                component_name: {
                    "replicas": 0,
                    "subComponentType": sub_component_type.value,
                }
            }
        },
    }
    mock_kube_api.get_graph_deployment.return_value = mock_deployment

    # Act
    await kubernetes_connector.remove_component(sub_component_type)

    # Assert
    mock_kube_api.update_graph_replicas.assert_not_called()
    mock_kube_api.wait_for_graph_deployment_ready.assert_not_called()
@pytest.mark.asyncio
async def test_remove_component_component_not_found(
    kubernetes_connector, mock_kube_api
):
    """Removing a decode worker fails when only a prefill service exists."""
    # Arrange
    present_type = SubComponentType.PREFILL
    prefill_only = {
        "test-component": {
            "replicas": 0,
            "subComponentType": present_type.value,
        }
    }
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": prefill_only},
    }

    # Act
    with pytest.raises(SubComponentNotFoundError) as exc_info:
        await kubernetes_connector.remove_component(SubComponentType.DECODE)

    # Assert
    mock_kube_api.update_graph_replicas.assert_not_called()
    mock_kube_api.wait_for_graph_deployment_ready.assert_not_called()
    raised = exc_info.value
    assert raised.sub_component_type == "decode"
@pytest.mark.asyncio
async def test_set_component_replicas(kubernetes_connector, mock_kube_api):
# Arrange
target_replicas = {"component1": 3, "component2": 2}
target_replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=3),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name="component2",
desired_replicas=2,
),
]
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {
"services": {"component1": {"replicas": 1}, "component2": {"replicas": 1}}
"services": {
"component1": {"replicas": 1, "subComponentType": "prefill"},
"component2": {"replicas": 1},
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.is_deployment_ready.return_value = True
mock_kube_api.update_graph_replicas.return_value = None
mock_kube_api.wait_for_graph_deployment_ready.return_value = None
# Act
......@@ -161,9 +383,81 @@ async def test_set_component_replicas(kubernetes_connector, mock_kube_api):
# Assert
mock_kube_api.get_graph_deployment.assert_called_once()
mock_kube_api.is_deployment_ready.assert_called_once_with("test-graph")
mock_kube_api.is_deployment_ready.assert_called_once_with(mock_deployment)
# Should be called twice, once for each component
assert mock_kube_api.update_graph_replicas.call_count == 2
expected_calls = [
call("test-graph", "component1", 3), # prefill component with 3 replicas
call("test-graph", "component2", 2), # decode component with 2 replicas
]
mock_kube_api.update_graph_replicas.assert_has_calls(expected_calls, any_order=True)
mock_kube_api.wait_for_graph_deployment_ready.assert_called_once_with("test-graph")
@pytest.mark.asyncio
async def test_set_component_replicas_component_not_found(
    kubernetes_connector, mock_kube_api
):
    """A decode target with no matching service raises ``SubComponentNotFoundError``."""
    # Arrange: component2 has no subComponentType and the decode target
    # carries no component_name, so nothing resolves it.
    services = {
        "component1": {"replicas": 1, "subComponentType": "prefill"},
        "component2": {"replicas": 1},
    }
    deployment = {"metadata": {"name": "test-graph"}, "spec": {"services": services}}
    requested = [
        TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=3),
        TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=2),
    ]
    mock_kube_api.get_graph_deployment.return_value = deployment
    mock_kube_api.is_deployment_ready.return_value = True
    mock_kube_api.update_graph_replicas.return_value = None
    mock_kube_api.wait_for_graph_deployment_ready.return_value = None

    # Act
    with pytest.raises(SubComponentNotFoundError) as exc_info:
        await kubernetes_connector.set_component_replicas(requested)

    # Assert
    raised = exc_info.value
    assert raised.sub_component_type == SubComponentType.DECODE.value
@pytest.mark.asyncio
async def test_set_component_replicas_component_already_at_desired_replicas(
    kubernetes_connector, mock_kube_api
):
    """Only services whose replica count differs from the target are patched."""
    # Arrange: decode is already at 2, prefill must go from 1 to 3.
    services = {
        "component1": {"replicas": 1, "subComponentType": "prefill"},
        "component2": {"replicas": 2, "subComponentType": "decode"},
    }
    deployment = {"metadata": {"name": "test-graph"}, "spec": {"services": services}}
    requested = [
        TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=3),
        TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=2),
    ]
    mock_kube_api.get_graph_deployment.return_value = deployment
    mock_kube_api.is_deployment_ready.return_value = True
    mock_kube_api.update_graph_replicas.return_value = None
    mock_kube_api.wait_for_graph_deployment_ready.return_value = None

    # Act
    await kubernetes_connector.set_component_replicas(requested)

    # Assert
    mock_kube_api.get_graph_deployment.assert_called_once()
    mock_kube_api.is_deployment_ready.assert_called_once_with(deployment)
    # Should be called once, for the prefill component (decode component is already at desired replicas)
    mock_kube_api.update_graph_replicas.assert_called_once_with(
        "test-graph", "component1", 3
    )
    mock_kube_api.wait_for_graph_deployment_ready.assert_called_once_with("test-graph")
......@@ -172,11 +466,15 @@ async def test_set_component_replicas_deployment_not_found(
kubernetes_connector, mock_kube_api
):
# Arrange
target_replicas = {"component1": 3}
mock_kube_api.get_graph_deployment.return_value = None
target_replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=3)
]
mock_kube_api.get_graph_deployment.side_effect = DynamoGraphDeploymentNotFoundError(
"test-graph", "default"
)
# Act & Assert
with pytest.raises(ValueError, match="Parent DynamoGraphDeployment not found"):
with pytest.raises(DynamoGraphDeploymentNotFoundError):
await kubernetes_connector.set_component_replicas(target_replicas)
......@@ -185,8 +483,201 @@ async def test_set_component_replicas_empty_target_replicas(
kubernetes_connector, mock_kube_api
):
# Arrange
target_replicas: dict[str, int] = {}
target_replicas: list[TargetReplica] = []
# Act & Assert
with pytest.raises(ValueError, match="target_replicas cannot be empty"):
with pytest.raises(EmptyTargetReplicasError):
await kubernetes_connector.set_component_replicas(target_replicas)
@pytest.mark.asyncio
async def test_set_component_replicas_deployment_not_ready(
    kubernetes_connector, mock_kube_api
):
    """When the deployment is not ready, no replicas are patched or waited on."""
    # Arrange
    target_replicas = [
        TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=3),
        TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=2),
    ]
    mock_deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                "component1": {"replicas": 1, "subComponentType": "prefill"},
                "component2": {"replicas": 2, "subComponentType": "decode"},
            }
        },
    }
    mock_kube_api.get_graph_deployment.return_value = mock_deployment
    mock_kube_api.is_deployment_ready.return_value = False

    # Act
    await kubernetes_connector.set_component_replicas(target_replicas)

    # Assert: readiness was checked once and scaling was skipped.
    mock_kube_api.get_graph_deployment.assert_called_once()
    mock_kube_api.is_deployment_ready.assert_called_once_with(mock_deployment)
    mock_kube_api.update_graph_replicas.assert_not_called()
    mock_kube_api.wait_for_graph_deployment_ready.assert_not_called()
@pytest.mark.asyncio
async def test_validate_deployment_true(kubernetes_connector, mock_kube_api):
    """Validation passes for one prefill and one decode service (no raise)."""
    # Arrange
    prefill_service = {
        "replicas": 1,
        "subComponentType": "prefill",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "prefill-model"]}
        },
    }
    decode_service = {"replicas": 2, "subComponentType": "decode"}
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                "component1": prefill_service,
                "component2": decode_service,
            }
        },
    }

    # Act: completing without an exception is the whole assertion.
    await kubernetes_connector.validate_deployment(decode_component_name="component2")
@pytest.mark.asyncio
async def test_validate_deployment_fail(kubernetes_connector, mock_kube_api):
    """Two prefill services and no decode service yield both validation errors."""
    # Arrange
    both_prefill = {
        "component1": {"replicas": 1, "subComponentType": "prefill"},
        "component2": {"replicas": 2, "subComponentType": "prefill"},
    }
    mock_kube_api.get_graph_deployment.return_value = {
        "metadata": {"name": "test-graph"},
        "spec": {"services": both_prefill},
    }
    # Expected messages are built from the exception types themselves.
    expected_errors = {
        str(DuplicateSubComponentError("prefill", ["component1", "component2"])),
        str(SubComponentNotFoundError("decode")),
    }

    # Act
    with pytest.raises(DeploymentValidationError) as exc_info:
        await kubernetes_connector.validate_deployment()

    # Assert
    raised = exc_info.value
    assert set(raised.errors) == expected_errors
def test_get_model_name_both_none_raises_error(kubernetes_connector, mock_kube_api):
    """No service declares ``--served-model-name``, so lookup must raise."""
    # Arrange: neither service carries an extraPodSpec with model args.
    services = {
        "component1": {"replicas": 1, "subComponentType": "prefill"},
        "component2": {"replicas": 2, "subComponentType": "decode"},
    }
    deployment = {"metadata": {"name": "test-graph"}, "spec": {"services": services}}

    # Act & Assert
    with pytest.raises(ModelNameNotFoundError):
        kubernetes_connector.get_model_name(deployment)
def test_get_model_name_prefill_none_decode_valid_returns_decode(kubernetes_connector):
    """Only the decode service names a model, and that name is returned."""
    # Arrange
    decode_service = {
        "replicas": 2,
        "subComponentType": "decode",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "test-model"]}
        },
    }
    deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                "component1": {"replicas": 1, "subComponentType": "prefill"},
                "component2": decode_service,
            }
        },
    }

    # Act
    model_name = kubernetes_connector.get_model_name(deployment)

    # Assert
    assert model_name == "test-model"
def test_get_model_name_mismatch_raises_error(kubernetes_connector, mock_kube_api):
    """Different model names on prefill and decode raise a mismatch error."""
    # Arrange
    prefill_service = {
        "replicas": 1,
        "subComponentType": "prefill",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "prefill-model"]}
        },
    }
    decode_service = {
        "replicas": 2,
        "subComponentType": "decode",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "decode-model"]}
        },
    }
    deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                "component1": prefill_service,
                "component2": decode_service,
            }
        },
    }
    mock_kube_api.get_graph_deployment.return_value = deployment

    # Act & Assert
    with pytest.raises(DeploymentModelNameMismatchError) as exc_info:
        kubernetes_connector.get_model_name(deployment)

    raised = exc_info.value
    assert raised.prefill_model_name == "prefill-model"
    assert raised.decode_model_name == "decode-model"
def test_get_model_name_agree_returns_model_name(kubernetes_connector, mock_kube_api):
    """When prefill and decode name the same model, that name is returned."""
    # Arrange
    prefill_service = {
        "replicas": 1,
        "subComponentType": "prefill",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "agreed-model"]}
        },
    }
    decode_service = {
        "replicas": 2,
        "subComponentType": "decode",
        "extraPodSpec": {
            "mainContainer": {"args": ["--served-model-name", "agreed-model"]}
        },
    }
    deployment = {
        "metadata": {"name": "test-graph"},
        "spec": {
            "services": {
                "component1": prefill_service,
                "component2": decode_service,
            }
        },
    }
    mock_kube_api.get_graph_deployment.return_value = deployment

    # Act
    model_name = kubernetes_connector.get_model_name(deployment)

    # Assert
    assert model_name == "agreed-model"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from unittest.mock import patch
import pytest
from dynamo.planner.utils.prometheus import (
FrontendMetric,
FrontendMetricContainer,
PrometheusAPIClient,
)
@pytest.fixture
def mock_prometheus_result():
    """Return four raw Prometheus samples used by the averaging tests.

    The first sample belongs to ``different_namespace``/``different_model``;
    the remaining three belong to ``target_namespace``/``target_model`` and
    differ in container name and sample value.
    """
    # (container, dynamo_namespace, model, timestamp, sample value)
    rows = (
        ("main", "different_namespace", "different_model", 1758857776.071, 10.5),
        ("main", "target_namespace", "target_model", 1758857776.071, 42.7),
        ("worker", "target_namespace", "target_model", 1758857776.071, 35.5),
        ("sidecar", "target_namespace", "target_model", 30.0, 15.5),
    )
    return [
        {
            "metric": {
                "container": container,
                "dynamo_namespace": dynamo_namespace,
                "model": model,
                "namespace": "dynamo-system",
            },
            "value": [timestamp, sample],
        }
        for container, dynamo_namespace, model, timestamp, sample in rows
    ]
def test_frontend_metric_container_with_nan_value():
    """Validate a sample whose value is the string "NaN", then a numeric one.

    The first validation must keep every metric label and turn the "NaN"
    string into a float NaN. After the raw value is replaced with 42.5, the
    same payload must validate to exactly 42.5.
    """
    labels = {
        "container": "main",
        "dynamo_namespace": "vllm-disagg-planner",
        "endpoint": "http",
        "instance": "10.244.2.163:8000",
        "job": "dynamo-system/dynamo-frontend",
        "model": "qwen/qwen3-0.6b",
        "namespace": "dynamo-system",
        "pod": "vllm-disagg-planner-frontend-865f84c49-6q7s5",
    }
    raw_sample = {"metric": labels, "value": [1758857776.071, "NaN"]}

    parsed = FrontendMetricContainer.model_validate(raw_sample)
    for label, expected in labels.items():
        assert getattr(parsed.metric, label) == expected
    assert parsed.value[0] == 1758857776.071
    # NaN never compares equal to itself, so it has to be detected with isnan.
    assert math.isnan(parsed.value[1])

    raw_sample["value"][1] = 42.5  # type: ignore[index]
    reparsed = FrontendMetricContainer.model_validate(raw_sample)
    assert reparsed.value[1] == 42.5
def test_frontend_metric_with_partial_data():
    """FrontendMetric accepts a subset of labels and leaves the rest as None."""
    provided = {
        "container": "main",
        "model": "qwen/qwen3-0.6b",
        "namespace": "dynamo-system",
    }
    absent = ("dynamo_namespace", "endpoint", "instance", "job", "pod")

    metric = FrontendMetric.model_validate(provided)

    # Labels present in the input are carried over unchanged.
    for field_name, expected in provided.items():
        assert getattr(metric, field_name) == expected
    # Labels missing from the input default to None.
    for field_name in absent:
        assert getattr(metric, field_name) is None
def test_get_average_metric_none_result():
    """_get_average_metric is expected to return 0 when the query yields None."""
    prometheus_client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
    query_kwargs = {
        "metric_name": "test_metric",
        "interval": "60s",
        "operation_name": "test operation",
        "model_name": "test_model",
    }
    # Replace the underlying Prometheus query so no server is contacted.
    with patch.object(prometheus_client.prom, "custom_query", return_value=None):
        average = prometheus_client._get_average_metric(**query_kwargs)
    assert average == 0
def test_get_average_metric_empty_result():
    """_get_average_metric is expected to return 0 for an empty query result."""
    prometheus_client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
    query_kwargs = {
        "metric_name": "test_metric",
        "interval": "60s",
        "operation_name": "test operation",
        "model_name": "test_model",
    }
    # Replace the underlying Prometheus query so no server is contacted.
    with patch.object(prometheus_client.prom, "custom_query", return_value=[]):
        average = prometheus_client._get_average_metric(**query_kwargs)
    assert average == 0
def test_get_average_metric_no_matching_containers(mock_prometheus_result):
    """_get_average_metric is expected to return 0 when no sample matches."""
    prometheus_client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
    # Only the first fixture sample is returned; it carries a different
    # namespace and model than the ones requested below.
    non_matching = [mock_prometheus_result[0]]
    with patch.object(
        prometheus_client.prom, "custom_query", return_value=non_matching
    ):
        average = prometheus_client._get_average_metric(
            metric_name="test_metric",
            interval="60s",
            operation_name="test operation",
            model_name="target_model",
        )
    assert average == 0
def test_get_average_metric_one_matching_container(mock_prometheus_result):
    """With a single matching sample the average is that sample's value."""
    prometheus_client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
    # First two fixture samples: one for another namespace/model, one for the
    # requested namespace/model.
    first_two = mock_prometheus_result[:2]
    with patch.object(prometheus_client.prom, "custom_query", return_value=first_two):
        average = prometheus_client._get_average_metric(
            metric_name="test_metric",
            interval="60s",
            operation_name="test operation",
            model_name="target_model",
        )
    assert average == 42.7
def test_get_average_metric_with_validation_error():
    """A malformed sample is expected to be ignored, leaving the valid one.

    The query result mixes one well-formed sample with one entry that has no
    ``metric`` mapping and a non-sequence ``value``; the expected average is
    the value of the well-formed sample alone.
    """
    prometheus_client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
    well_formed = {
        "metric": {
            "container": "main",
            "dynamo_namespace": "target_namespace",
            "model": "target_model",
            "namespace": "dynamo-system",
        },
        "value": [1758857776.071, 25.5],
    }
    # Invalid structure - missing required fields that will cause validation error
    malformed = {
        "invalid_structure": "bad_data",
        "value": "not_a_tuple",
    }
    with patch.object(
        prometheus_client.prom, "custom_query", return_value=[well_formed, malformed]
    ):
        average = prometheus_client._get_average_metric(
            metric_name="test_metric",
            interval="60s",
            operation_name="test operation",
            model_name="target_model",
        )
    assert average == 25.5
def test_get_average_metric_multiple_matching_containers(mock_prometheus_result):
    """Test _get_average_metric with multiple matching containers returns average.

    The query returns fixture samples 1, 2 and 3, which all carry the
    requested namespace and model; the result must be the mean of their
    sample values.
    """
    client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
    with patch.object(client.prom, "custom_query") as mock_query:
        # Use containers 1, 2, 3 which all match target criteria
        mock_query.return_value = mock_prometheus_result[1:]
        result = client._get_average_metric(
            metric_name="test_metric",
            interval="60s",
            operation_name="test operation",
            model_name="target_model",
        )
    # Average of 42.7, 35.5, and 15.5 (using value[1] from each container)
    expected = (42.7 + 35.5 + 15.5) / 3
    # Compare with a tolerance: the implementation may accumulate the floats
    # in a different order than the expression above, which can change the
    # last bits of the result without being a real failure.
    assert result == pytest.approx(expected)
......@@ -11,7 +11,7 @@ import logging
import pytest
from dynamo._core import DistributedRuntime, VirtualConnectorClient
from dynamo.planner import VirtualConnector
from dynamo.planner import SubComponentType, TargetReplica, VirtualConnector
pytestmark = pytest.mark.pre_merge
logger = logging.getLogger(__name__)
......@@ -49,7 +49,10 @@ def test_main():
async def next_scaling_decision(c):
    """Move the second decision into a separate task so we can `.wait` for it.

    Requests 5 prefill and 8 decode replicas from connector ``c`` without
    blocking on completion of the scaling operation.
    """
    # The legacy ``{"prefill": 5, "decode": 8}`` dict was a dead store that
    # was immediately overwritten; only the TargetReplica list is passed on.
    replicas = [
        TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=5),
        TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=8),
    ]
    await c.set_component_replicas(replicas, blocking=False)
......@@ -57,7 +60,10 @@ async def async_internal(distributed_runtime):
# This is Dynamo Planner
c = VirtualConnector(distributed_runtime, NAMESPACE, "sglang")
await c._async_init()
replicas = {"prefill": 1, "decode": 2}
replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=1),
TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=2),
]
await c.set_component_replicas(replicas, blocking=False)
# This is the client
......@@ -86,7 +92,10 @@ async def async_internal(distributed_runtime):
await c._wait_for_scaling_completion()
# Now scale to zero
replicas = {"prefill": 0, "decode": 0}
replicas = [
TargetReplica(sub_component_type=SubComponentType.PREFILL, desired_replicas=0),
TargetReplica(sub_component_type=SubComponentType.DECODE, desired_replicas=0),
]
await c.set_component_replicas(replicas, blocking=False)
event = await client.get()
assert event.num_prefill_workers == 0
......
......@@ -156,21 +156,6 @@ RUN apt-get update && \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
# Install the `prometheus` binary from the upstream GitHub release.
# PROM_VERSION selects the release to download. The RUN step maps the dpkg
# architecture (amd64 or arm64) to the release platform name and fails the
# build on any other architecture. It then downloads the tarball (up to 5
# retries, 5 s apart), unpacks it into /tmp, moves the binary to
# /usr/local/bin, marks it executable, and removes the unpacked directory.
ARG PROM_VERSION=3.4.1
RUN ARCH=$(dpkg --print-architecture) && \
case "$ARCH" in \
amd64) PLATFORM=linux-amd64 ;; \
arm64) PLATFORM=linux-arm64 ;; \
*) echo "Unsupported architecture: $ARCH" && exit 1 ;; \
esac && \
curl -fsSL --retry 5 --retry-delay 5 "https://github.com/prometheus/prometheus/releases/download/v${PROM_VERSION}/prometheus-${PROM_VERSION}.${PLATFORM}.tar.gz" \
| tar -xz -C /tmp && \
mv "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}/prometheus" /usr/local/bin/ && \
chmod +x /usr/local/bin/prometheus && \
rm -rf "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}"
# Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image
COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc
COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++
......
......@@ -99,20 +99,6 @@ RUN apt-get update && \
jq && \
rm -rf /var/lib/apt/lists/*
# Install the `prometheus` binary from the upstream GitHub release.
# PROM_VERSION selects the release to download. The RUN step maps the dpkg
# architecture (amd64 or arm64) to the release platform name and fails the
# build on any other architecture. It then downloads the tarball (up to 5
# retries, 5 s apart), unpacks it into /tmp, moves the binary to
# /usr/local/bin, marks it executable, and removes the unpacked directory.
ARG PROM_VERSION=3.4.1
RUN ARCH=$(dpkg --print-architecture) && \
case "$ARCH" in \
amd64) PLATFORM=linux-amd64 ;; \
arm64) PLATFORM=linux-arm64 ;; \
*) echo "Unsupported architecture: $ARCH" && exit 1 ;; \
esac && \
curl -fsSL --retry 5 --retry-delay 5 "https://github.com/prometheus/prometheus/releases/download/v${PROM_VERSION}/prometheus-${PROM_VERSION}.${PLATFORM}.tar.gz" \
| tar -xz -C /tmp && \
mv "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}/prometheus" /usr/local/bin/ && \
chmod +x /usr/local/bin/prometheus && \
rm -rf "/tmp/prometheus-${PROM_VERSION}.${PLATFORM}"
# Copy CUDA development tools (nvcc, headers, dependencies, etc.) from framework devel image
COPY --from=framework /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc
COPY --from=framework /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment