"lib/bindings/vscode:/vscode.git/clone" did not exist on "71be641d16f2e1741df17eabfa6e582eff493d94"
Unverified Commit 0a266653 authored by daiyaanarfeen's avatar daiyaanarfeen Committed by GitHub
Browse files

feat: add GlobalPlanner component for centralized scaling (#5702)


Signed-off-by: default avatarDaiyaan <darfeen@nvidia.com>
Signed-off-by: default avatardaiyaanarfeen <darfeen@nvidia.com>
Co-authored-by: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
parent 733c3dce
......@@ -109,7 +109,9 @@ deploy:
planner:
- 'components/src/dynamo/planner/**'
- 'components/src/dynamo/global_planner/**'
- 'tests/planner/**'
- 'tests/global_planner/**'
- 'components/src/dynamo/profiler/**'
- 'components/src/dynamo/global_router/**'
......
......@@ -27,6 +27,7 @@ CODEOWNERS @ai-dynamo/Devops
# Planner
/components/src/dynamo/planner/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
/components/src/dynamo/global_router/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
/components/src/dynamo/global_planner/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
/examples/hierarchical_planner/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
/components/src/dynamo/profiler/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
/tests/planner/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GlobalPlanner - Centralized scaling execution service.
The GlobalPlanner is a standalone component that receives scale requests from
Planners and executes them via the Kubernetes API. It provides centralized
scaling management across multiple deployments and namespaces.
Architecture:
- Planners make scaling decisions (observe, predict, decide)
- Planners in delegating mode send requests to GlobalPlanner
- GlobalPlanner executes scaling via Kubernetes API
- GlobalPlanner is stateless and can scale horizontally
Usage:
DYN_NAMESPACE=global-infra python -m dynamo.global_planner \
--managed-namespaces app-ns-1 app-ns-2
"""
__all__ = [
"ScaleRequestHandler",
]
from dynamo.global_planner.scale_handler import ScaleRequestHandler
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GlobalPlanner - Centralized Scaling Execution Service
Entry point for the GlobalPlanner component.
Usage:
DYN_NAMESPACE=global-infra python -m dynamo.global_planner
With authorization:
DYN_NAMESPACE=global-infra python -m dynamo.global_planner \\
--managed-namespaces app-ns-1 app-ns-2
"""
import asyncio
import logging
import os
from pydantic import BaseModel
from dynamo.global_planner.argparse_config import create_global_planner_parser
from dynamo.global_planner.scale_handler import ScaleRequestHandler
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class HealthCheckRequest(BaseModel):
"""Request type for health check endpoint"""
text: str = "ping"
@dynamo_worker()
async def main(runtime: DistributedRuntime, args):
"""Initialize and run GlobalPlanner.
The GlobalPlanner is a centralized scaling service that:
1. Listens for scale requests from Planners
2. Validates caller authorization (optional)
3. Executes scaling via Kubernetes API
4. Returns success/failure status
Args:
runtime: Dynamo runtime instance
args: Parsed command-line arguments
"""
# Get Dynamo namespace from environment variable
namespace = os.environ.get("DYN_NAMESPACE")
if not namespace:
raise ValueError(
"DYN_NAMESPACE environment variable is required but not set. "
"Please set DYN_NAMESPACE to specify the Dynamo namespace for GlobalPlanner."
)
logger.info("=" * 60)
logger.info("Starting GlobalPlanner")
logger.info("=" * 60)
logger.info(f"Namespace: {namespace}")
logger.info(f"Environment: {args.environment}")
if args.managed_namespaces:
logger.info("Authorization: ENABLED")
logger.info(f"Authorized namespaces: {args.managed_namespaces}")
else:
logger.info("Authorization: DISABLED (accepting all namespaces)")
logger.info("=" * 60)
# Create the GlobalPlanner component
component = runtime.namespace(namespace).component("GlobalPlanner")
# Get K8s namespace (where GlobalPlanner pod is running)
k8s_namespace = os.environ.get("POD_NAMESPACE", "default")
logger.info(f"Running in Kubernetes namespace: {k8s_namespace}")
# Create scale request handler
handler = ScaleRequestHandler(
runtime=runtime,
managed_namespaces=args.managed_namespaces,
k8s_namespace=k8s_namespace,
)
# Serve scale_request endpoint
logger.info("Serving endpoints...")
scale_endpoint = component.endpoint("scale_request")
await scale_endpoint.serve_endpoint(handler.scale_request)
logger.info(" ✓ scale_request - Receives scaling requests from Planners")
# Serve health check endpoint
async def health_check(request: HealthCheckRequest):
"""Health check endpoint for monitoring"""
yield {
"status": "healthy",
"component": "GlobalPlanner",
"namespace": namespace,
"managed_namespaces": args.managed_namespaces or "all",
}
health_endpoint = component.endpoint("health")
await health_endpoint.serve_endpoint(health_check)
logger.info(" ✓ health - Health check endpoint")
logger.info("=" * 60)
logger.info("GlobalPlanner is ready and waiting for scale requests")
logger.info("=" * 60)
# Keep running forever (process scale requests as they come)
await asyncio.Event().wait()
if __name__ == "__main__":
parser = create_global_planner_parser()
args = parser.parse_args()
asyncio.run(main(args))
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Argument parsing for GlobalPlanner."""
import argparse
def create_global_planner_parser() -> argparse.ArgumentParser:
"""Create and configure the argument parser for GlobalPlanner.
Returns:
argparse.ArgumentParser: Configured argument parser for GlobalPlanner
"""
parser = argparse.ArgumentParser(
description="GlobalPlanner - Centralized Scaling Execution Service",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Simple deployment (accept all namespaces)
DYN_NAMESPACE=global-infra python -m dynamo.global_planner
# With authorization
DYN_NAMESPACE=global-infra python -m dynamo.global_planner \\
--managed-namespaces app-ns-1 app-ns-2 app-ns-3
# Custom environment
DYN_NAMESPACE=global-infra python -m dynamo.global_planner \\
--environment=kubernetes
""",
)
parser.add_argument(
"--managed-namespaces",
type=str,
nargs="+",
default=None,
help="Optional: List of namespaces authorized to use this GlobalPlanner (default: accept all)",
)
parser.add_argument(
"--environment",
default="kubernetes",
choices=["kubernetes"],
help="Environment type (currently only kubernetes supported)",
)
return parser
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Handler for scale_request endpoint in GlobalPlanner."""
import logging
from dynamo.planner import KubernetesConnector
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse, ScaleStatus
from dynamo.runtime import DistributedRuntime, dynamo_endpoint
logger = logging.getLogger(__name__)
# Model name used for KubernetesConnector in remote execution mode
MANAGED_MODEL_NAME = "managed"
class ScaleRequestHandler:
"""Handles incoming scale requests in GlobalPlanner.
This handler:
1. Receives scale requests from Planners
2. Validates caller authorization (optional)
3. Caches KubernetesConnector per DGD for efficiency
4. Executes scaling via Kubernetes API
5. Returns current replica counts
"""
def __init__(
self, runtime: DistributedRuntime, managed_namespaces: list, k8s_namespace: str
):
"""Initialize the scale request handler.
Args:
runtime: Dynamo runtime instance
managed_namespaces: List of authorized namespaces (None = accept all)
k8s_namespace: Kubernetes namespace where GlobalPlanner is running
"""
self.runtime = runtime
# If managed_namespaces is None, accept all namespaces
self.managed_namespaces = (
set(managed_namespaces) if managed_namespaces else None
)
self.k8s_namespace = k8s_namespace
self.connectors = {} # Cache of KubernetesConnector per DGD
if self.managed_namespaces:
logger.info(
f"ScaleRequestHandler initialized for namespaces: {managed_namespaces}"
)
else:
logger.info("ScaleRequestHandler initialized (accepting all namespaces)")
@dynamo_endpoint(ScaleRequest, ScaleResponse)
async def scale_request(self, request: ScaleRequest):
"""Process scaling request from a Planner.
Args:
request: ScaleRequest with target replicas and DGD info
Yields:
ScaleResponse with status and current replica counts
"""
try:
# Validate caller namespace (if authorization is enabled)
if (
self.managed_namespaces is not None
and request.caller_namespace not in self.managed_namespaces
):
yield {
"status": ScaleStatus.ERROR.value,
"message": f"Namespace {request.caller_namespace} not authorized",
"current_replicas": {},
}
return
logger.info(
f"Processing scale request from {request.caller_namespace} "
f"for DGD {request.graph_deployment_name} "
f"in K8s namespace {request.k8s_namespace}"
)
# Get or create connector for this DGD
connector_key = f"{request.k8s_namespace}/{request.graph_deployment_name}"
if connector_key not in self.connectors:
connector = KubernetesConnector(
dynamo_namespace=request.caller_namespace,
model_name=MANAGED_MODEL_NAME, # Not used for remote execution
k8s_namespace=request.k8s_namespace,
parent_dgd_name=request.graph_deployment_name,
)
await connector._async_init()
self.connectors[connector_key] = connector
logger.debug(f"Created new connector for {connector_key}")
else:
connector = self.connectors[connector_key]
logger.debug(f"Reusing cached connector for {connector_key}")
# Execute scaling (request.target_replicas is already List[TargetReplica])
await connector.set_component_replicas(
request.target_replicas, blocking=request.blocking
)
# Get current replica counts
current_replicas = {}
deployment = connector.kube_api.get_graph_deployment(
connector.parent_dgd_name
)
for service_name, service_spec in deployment["spec"]["services"].items():
sub_type = service_spec.get("subComponentType", "")
if sub_type:
current_replicas[sub_type] = service_spec.get("replicas", 0)
logger.info(
f"Successfully scaled {request.graph_deployment_name}: {current_replicas}"
)
yield {
"status": ScaleStatus.SUCCESS.value,
"message": f"Scaled {request.graph_deployment_name} successfully",
"current_replicas": current_replicas,
}
except Exception as e:
logger.exception(f"Error processing scale request: {e}")
yield {
"status": ScaleStatus.ERROR.value,
"message": str(e),
"current_replicas": {},
}
......@@ -5,12 +5,14 @@ __all__ = [
"PlannerConnector",
"KubernetesConnector",
"VirtualConnector",
"GlobalPlannerConnector",
"SLAPlannerDefaults",
"TargetReplica",
"SubComponentType",
]
# Import the classes
from dynamo.planner.defaults import SLAPlannerDefaults, SubComponentType
from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.virtual_connector import VirtualConnector
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Connector for delegating scaling decisions to a centralized GlobalPlanner."""
import logging
import os
import time
from typing import Optional
from dynamo.planner.defaults import SubComponentType
from dynamo.planner.kubernetes_connector import TargetReplica
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.remote_planner_client import RemotePlannerClient
from dynamo.planner.scale_protocol import ScaleRequest, ScaleStatus
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class GlobalPlannerConnector(PlannerConnector):
"""
Connector that delegates scaling decisions to a centralized GlobalPlanner.
This connector wraps RemotePlannerClient and implements the PlannerConnector
interface, allowing planner_core.py to treat global-planner environment mode
consistently with kubernetes and virtual modes.
"""
def __init__(
self,
runtime: DistributedRuntime,
dynamo_namespace: str,
global_planner_namespace: str,
global_planner_component: str = "GlobalPlanner",
model_name: Optional[str] = None,
):
"""
Initialize GlobalPlannerConnector.
Args:
runtime: Distributed runtime for communication
dynamo_namespace: Local dynamo namespace (caller identification)
global_planner_namespace: Namespace where GlobalPlanner is deployed
global_planner_component: Component name of GlobalPlanner (default: "GlobalPlanner")
model_name: Optional model name (will be managed remotely if not provided)
"""
self.runtime = runtime
self.dynamo_namespace = dynamo_namespace
self.global_planner_namespace = global_planner_namespace
self.global_planner_component = global_planner_component
self.model_name = model_name
self.remote_client: Optional[RemotePlannerClient] = None
# Cache for predicted load (will be set by planner before scaling)
self.last_predicted_load: Optional[dict] = None
async def _async_init(self):
"""Async initialization - creates RemotePlannerClient"""
self.remote_client = RemotePlannerClient(
self.runtime,
self.global_planner_namespace,
self.global_planner_component,
)
logger.info(
f"GlobalPlannerConnector initialized: will delegate to {self.global_planner_namespace}.{self.global_planner_component}"
)
def set_predicted_load(
self, num_requests: Optional[float], isl: Optional[float], osl: Optional[float]
):
"""
Set predicted load for inclusion in next scale request.
This is called by planner_core.py before calling set_component_replicas.
"""
self.last_predicted_load = {
"num_requests": num_requests,
"isl": isl,
"osl": osl,
}
async def set_component_replicas(
self, target_replicas: list[TargetReplica], blocking: bool = True
):
"""
Set component replicas by delegating to GlobalPlanner.
Sends a ScaleRequest to the GlobalPlanner with the target replica configurations.
Args:
target_replicas: List of target replica configurations
blocking: Whether to wait for scaling completion (passed to GlobalPlanner)
Raises:
EmptyTargetReplicasError: If target_replicas is empty
RuntimeError: If remote_client is not initialized or response indicates error
"""
if not target_replicas:
raise EmptyTargetReplicasError()
if self.remote_client is None:
raise RuntimeError(
"GlobalPlannerConnector not initialized. Call _async_init() first."
)
# Get DGD info from environment variables
graph_deployment_name = os.environ.get("DYN_PARENT_DGD_K8S_NAME")
if not graph_deployment_name:
raise ValueError(
"DYN_PARENT_DGD_K8S_NAME environment variable is required but not set. "
"Please set DYN_PARENT_DGD_K8S_NAME to specify the parent DGD name."
)
k8s_namespace = os.environ.get("POD_NAMESPACE")
if not k8s_namespace:
raise ValueError(
"POD_NAMESPACE environment variable is required but not set. "
"Please set POD_NAMESPACE to specify the Kubernetes namespace."
)
# Create scale request
request = ScaleRequest(
caller_namespace=self.dynamo_namespace,
graph_deployment_name=graph_deployment_name,
k8s_namespace=k8s_namespace,
target_replicas=target_replicas,
blocking=blocking,
timestamp=time.time(),
predicted_load=self.last_predicted_load,
)
logger.info(
f"Delegating scale request to GlobalPlanner: "
f"DGD={graph_deployment_name}, "
f"prefill={[r.desired_replicas for r in target_replicas if r.sub_component_type == SubComponentType.PREFILL]}, "
f"decode={[r.desired_replicas for r in target_replicas if r.sub_component_type == SubComponentType.DECODE]}"
)
# Send request to GlobalPlanner
response = await self.remote_client.send_scale_request(request)
# Check response status
if response.status == ScaleStatus.SUCCESS:
logger.info(f"GlobalPlanner scaling successful: {response.message}")
elif response.status == ScaleStatus.ERROR:
logger.error(f"GlobalPlanner scaling failed: {response.message}")
raise RuntimeError(f"GlobalPlanner scaling failed: {response.message}")
else:
logger.warning(
f"GlobalPlanner returned status '{response.status.value}': {response.message}"
)
async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""
Add a component (not supported for GlobalPlanner).
GlobalPlanner only supports batch operations via set_component_replicas.
"""
raise NotImplementedError(
"GlobalPlannerConnector only supports batch operations via set_component_replicas(). "
"Use set_component_replicas() to scale components."
)
async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
):
"""
Remove a component (not supported for GlobalPlanner).
GlobalPlanner only supports batch operations via set_component_replicas.
"""
raise NotImplementedError(
"GlobalPlannerConnector only supports batch operations via set_component_replicas(). "
"Use set_component_replicas() to scale components."
)
async def validate_deployment(
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
):
"""
Validate deployment (no-op for GlobalPlanner).
The GlobalPlanner validates the deployment on its side, so local
validation is not needed in delegating mode.
"""
logger.info(
"GlobalPlannerConnector: Skipping local deployment validation "
"(GlobalPlanner will validate on its side)"
)
async def wait_for_deployment_ready(self):
"""
Wait for deployment to be ready (no-op for GlobalPlanner).
The GlobalPlanner manages deployment state, so we don't need to
wait locally in delegating mode.
"""
logger.info(
"GlobalPlannerConnector: Skipping deployment ready check "
"(GlobalPlanner manages deployment state)"
)
def get_model_name(self) -> str:
"""
Get model name.
Returns the model name if provided during initialization, otherwise
returns a placeholder indicating the model is managed remotely.
"""
if self.model_name:
return self.model_name
return "managed-remotely"
......@@ -51,6 +51,7 @@ class KubernetesConnector(PlannerConnector):
dynamo_namespace: str,
model_name: Optional[str] = None,
k8s_namespace: Optional[str] = None,
parent_dgd_name: Optional[str] = None,
):
self.kube_api = KubernetesAPI(k8s_namespace)
......@@ -60,13 +61,19 @@ class KubernetesConnector(PlannerConnector):
model_name.lower()
) # normalize model name to lowercase (MDC)
# Allow overriding parent DGD name for centralized planner
if parent_dgd_name:
self.parent_dgd_name = parent_dgd_name
else:
graph_deployment_name = os.getenv("DYN_PARENT_DGD_K8S_NAME")
if not graph_deployment_name:
raise DeploymentValidationError(
["DYN_PARENT_DGD_K8S_NAME environment variable is not set"]
)
self.parent_dgd_name = graph_deployment_name
self.graph_deployment_name = graph_deployment_name
# For backwards compatibility
self.graph_deployment_name = self.parent_dgd_name
async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
......@@ -411,7 +418,9 @@ if __name__ == "__main__":
)
parser.add_argument("--blocking", action="store_true")
args = parser.parse_args()
connector = KubernetesConnector(args.dynamo_namespace, args.k8s_namespace)
connector = KubernetesConnector(
args.dynamo_namespace, k8s_namespace=args.k8s_namespace
)
if args.action == "add":
task = connector.add_component(SubComponentType(args.component), args.blocking)
......
......@@ -77,4 +77,5 @@ async def init_planner(runtime: DistributedRuntime, args):
if __name__ == "__main__":
parser = create_sla_planner_parser()
args = parser.parse_args()
validate_sla_planner_args(args)
asyncio.run(init_planner(args))
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Client for calling remote planner's scale_request endpoint."""
import asyncio
import logging
from dynamo.planner.defaults import SubComponentType
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
class RemotePlannerClient:
"""Client for delegating scaling requests to centralized planner"""
def __init__(
self,
runtime: DistributedRuntime,
central_namespace: str,
central_component: str,
connection_timeout: float = 30.0,
max_retries: int = 3,
):
self.runtime = runtime
self.central_namespace = central_namespace
self.central_component = central_component
self.connection_timeout = connection_timeout
self.max_retries = max_retries
self._client = None
async def _ensure_client(self):
"""Lazy initialization of endpoint client with retry mechanism"""
if self._client is None:
endpoint = (
self.runtime.namespace(self.central_namespace)
.component(self.central_component)
.endpoint("scale_request")
)
# Retry logic with exponential backoff
last_error = None
for attempt in range(self.max_retries):
try:
logger.info(
f"Attempting to connect to GlobalPlanner at "
f"{self.central_namespace}.{self.central_component} "
f"(attempt {attempt + 1}/{self.max_retries})"
)
self._client = await endpoint.client()
# Wait for instances with timeout
await asyncio.wait_for(
self._client.wait_for_instances(),
timeout=self.connection_timeout,
)
logger.info(
f"Successfully connected to centralized planner at "
f"{self.central_namespace}.{self.central_component}"
)
return
except asyncio.TimeoutError as e:
last_error = e
logger.warning(
f"Connection attempt {attempt + 1} timed out after "
f"{self.connection_timeout}s"
)
self._client = None
except Exception as e:
last_error = e
logger.warning(f"Connection attempt {attempt + 1} failed: {e}")
self._client = None
# Exponential backoff before retry (except on last attempt)
if attempt < self.max_retries - 1:
backoff = 2**attempt # 1s, 2s, 4s, ...
logger.info(f"Retrying in {backoff}s...")
await asyncio.sleep(backoff)
# All retries exhausted
raise RuntimeError(
f"Failed to connect to GlobalPlanner at "
f"{self.central_namespace}.{self.central_component} after "
f"{self.max_retries} attempts. Last error: {last_error}"
)
async def send_scale_request(self, request: ScaleRequest) -> ScaleResponse:
"""Send scale request to centralized planner"""
await self._ensure_client()
logger.info(
f"Sending scale request to centralized planner: "
f"prefill={[r.desired_replicas for r in request.target_replicas if r.sub_component_type == SubComponentType.PREFILL]}, "
f"decode={[r.desired_replicas for r in request.target_replicas if r.sub_component_type == SubComponentType.DECODE]}"
)
# Send request to single endpoint
request_json = request.model_dump_json()
response_data = await self._client.scale_request(request_json)
if response_data is None:
raise RuntimeError("No response from centralized planner")
# Parse response
response = ScaleResponse(**response_data)
logger.info(f"Scale request response: {response.status} - {response.message}")
return response
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Data structures for scale request/response protocol between delegating and centralized planners."""
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from dynamo.planner.kubernetes_connector import TargetReplica
class ScaleStatus(str, Enum):
"""Status values for scaling operations"""
SUCCESS = "success"
ERROR = "error"
SCALING = "scaling"
class ScaleRequest(BaseModel):
"""Request to scale a deployment"""
# Caller identification
caller_namespace: str
# Target deployment
graph_deployment_name: str # K8s DynamoGraphDeployment name
k8s_namespace: str # K8s namespace
# Scaling targets
target_replicas: List[TargetReplica]
# Execution options
blocking: bool = False
# Optional context (for debugging/logging)
timestamp: Optional[float] = None
predicted_load: Optional[dict] = None
class ScaleResponse(BaseModel):
"""Response from scaling operation"""
status: ScaleStatus
message: str
current_replicas: dict # {"prefill": 3, "decode": 5}
......@@ -28,8 +28,8 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--environment",
default=SLAPlannerDefaults.environment,
choices=["kubernetes", "virtual"],
help="Environment type",
choices=["kubernetes", "virtual", "global-planner"],
help="Environment type: kubernetes (direct K8s scaling), virtual (dynamo runtime scaling), global-planner (delegate to GlobalPlanner)",
)
parser.add_argument(
"--namespace",
......@@ -177,6 +177,14 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
help="Model name of deployment (only required for virtual environment)",
)
# For global-planner environment mode
parser.add_argument(
"--global-planner-namespace",
type=str,
default=None,
help="Namespace of GlobalPlanner component (required when environment=global-planner)",
)
# Scaling mode flags
parser.add_argument(
"--enable-throughput-scaling",
......@@ -238,6 +246,16 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
return parser
def validate_planner_args(args):
"""Validate planner configuration"""
if args.environment == "global-planner":
if not args.global_planner_namespace:
raise ValueError(
"--global-planner-namespace required when environment=global-planner. "
"Please specify the namespace where GlobalPlanner is running."
)
def validate_sla_planner_args(args: argparse.Namespace) -> None:
"""Validate and normalize SLA planner arguments.
......
......@@ -18,6 +18,7 @@ from dynamo.planner import (
VirtualConnector,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.utils.exceptions import DeploymentValidationError
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import (
......@@ -267,8 +268,16 @@ class BasePlanner:
self.namespace = args.namespace
if not args.no_operation:
if connector is not None:
self.connector = connector
# Initialize connector based on environment
if args.environment == "global-planner":
# Use GlobalPlannerConnector to delegate to GlobalPlanner
self.connector = GlobalPlannerConnector(
runtime,
self.namespace,
args.global_planner_namespace,
"GlobalPlanner",
getattr(args, "model_name", None),
)
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(
self.namespace, self.model_name
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ScaleRequestHandler."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dynamo.global_planner.scale_handler import ScaleRequestHandler
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.scale_protocol import ScaleRequest
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20"),
]
@pytest.fixture
def mock_runtime():
"""Create a mock DistributedRuntime."""
return MagicMock()
@pytest.mark.asyncio
async def test_handler_authorization_success(mock_runtime):
"""Test handler authorizes requests from managed namespaces."""
handler = ScaleRequestHandler(
runtime=mock_runtime, managed_namespaces=["app-ns"], k8s_namespace="default"
)
request = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=3
)
],
)
# Mock KubernetesConnector
with patch(
"dynamo.global_planner.scale_handler.KubernetesConnector"
) as mock_connector_cls:
mock_connector = AsyncMock()
mock_connector_cls.return_value = mock_connector
mock_connector._async_init = AsyncMock()
mock_connector.set_component_replicas = AsyncMock()
mock_connector.kube_api = MagicMock()
mock_connector.kube_api.get_graph_deployment = MagicMock(
return_value={
"spec": {
"services": {
"prefill-svc": {"subComponentType": "prefill", "replicas": 3},
"decode-svc": {"subComponentType": "decode", "replicas": 5},
}
}
}
)
# Process request (pass as dict to match endpoint behavior)
results = []
async for response in handler.scale_request(request.model_dump()):
results.append(response)
assert len(results) == 1
response = results[0]
assert response["status"] == "success"
assert "Scaled" in response["message"]
assert response["current_replicas"]["prefill"] == 3
assert response["current_replicas"]["decode"] == 5
@pytest.mark.asyncio
async def test_handler_authorization_failure(mock_runtime):
"""Test handler rejects requests from unauthorized namespaces."""
handler = ScaleRequestHandler(
runtime=mock_runtime,
managed_namespaces=["authorized-ns"],
k8s_namespace="default",
)
request = ScaleRequest(
caller_namespace="unauthorized-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=3
)
],
)
# Process request
results = []
async for response in handler.scale_request(request.model_dump()):
results.append(response)
assert len(results) == 1
response = results[0]
assert response["status"] == "error"
assert "not authorized" in response["message"]
assert response["current_replicas"] == {}
@pytest.mark.asyncio
async def test_handler_multiple_dgds(mock_runtime):
"""Test handler creates separate connectors for different DGDs (and caches them)."""
handler = ScaleRequestHandler(
runtime=mock_runtime, managed_namespaces=["app-ns"], k8s_namespace="default"
)
request1 = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="dgd-1",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=2
)
],
)
request2 = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="dgd-2", # Different DGD
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=4
)
],
)
with patch(
"dynamo.global_planner.scale_handler.KubernetesConnector"
) as mock_connector_cls:
mock_connector = AsyncMock()
mock_connector_cls.return_value = mock_connector
mock_connector._async_init = AsyncMock()
mock_connector.set_component_replicas = AsyncMock()
mock_connector.kube_api = MagicMock()
mock_connector.kube_api.get_graph_deployment = MagicMock(
return_value={"spec": {"services": {}}}
)
# Process both requests
async for _ in handler.scale_request(request1.model_dump()):
pass
async for _ in handler.scale_request(request2.model_dump()):
pass
# Verify two connectors were created
assert "default/dgd-1" in handler.connectors
assert "default/dgd-2" in handler.connectors
assert mock_connector_cls.call_count == 2
@pytest.mark.asyncio
async def test_handler_error_handling(mock_runtime):
"""Test handler error handling during scaling."""
handler = ScaleRequestHandler(
runtime=mock_runtime, managed_namespaces=["app-ns"], k8s_namespace="default"
)
request = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=3
)
],
)
with patch(
"dynamo.global_planner.scale_handler.KubernetesConnector"
) as mock_connector_cls:
mock_connector = AsyncMock()
mock_connector_cls.return_value = mock_connector
mock_connector._async_init = AsyncMock()
# Simulate error during scaling
mock_connector.set_component_replicas = AsyncMock(
side_effect=Exception("Scaling failed")
)
# Process request (pass as dict to match endpoint behavior)
results = []
async for response in handler.scale_request(request.model_dump()):
results.append(response)
assert len(results) == 1
response = results[0]
assert response["status"] == "error"
assert "Scaling failed" in response["message"]
@pytest.mark.asyncio
async def test_handler_blocking_mode(mock_runtime):
"""Test handler respects blocking mode."""
handler = ScaleRequestHandler(
runtime=mock_runtime, managed_namespaces=["app-ns"], k8s_namespace="default"
)
request = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=3
)
],
blocking=True, # Request blocking mode
)
with patch(
"dynamo.global_planner.scale_handler.KubernetesConnector"
) as mock_connector_cls:
mock_connector = AsyncMock()
mock_connector_cls.return_value = mock_connector
mock_connector._async_init = AsyncMock()
mock_connector.set_component_replicas = AsyncMock()
mock_connector.kube_api = MagicMock()
mock_connector.kube_api.get_graph_deployment = MagicMock(
return_value={"spec": {"services": {}}}
)
# Process request (pass as dict to match endpoint behavior)
async for _ in handler.scale_request(request.model_dump()):
pass
# Verify blocking=True was passed to connector
mock_connector.set_component_replicas.assert_called_once()
call_args = mock_connector.set_component_replicas.call_args
assert call_args[1]["blocking"] is True
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for planner argument parsing and validation."""
import pytest
from dynamo.planner.utils.planner_argparse import (
create_sla_planner_parser,
validate_planner_args,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
def test_parser_global_planner_mode():
"""Test parser accepts global-planner environment mode arguments."""
parser = create_sla_planner_parser()
args = parser.parse_args(
[
"--namespace",
"test-ns",
"--environment",
"global-planner",
"--global-planner-namespace",
"global-ns",
]
)
assert args.environment == "global-planner"
assert args.global_planner_namespace == "global-ns"
def test_validate_global_planner_mode_without_namespace():
"""Test validation fails for global-planner environment without GlobalPlanner namespace."""
parser = create_sla_planner_parser()
args = parser.parse_args(
["--namespace", "test-ns", "--environment", "global-planner"]
)
with pytest.raises(ValueError, match="global-planner-namespace required"):
validate_planner_args(args)
def test_parser_invalid_environment():
"""Test parser rejects invalid environment."""
parser = create_sla_planner_parser()
with pytest.raises(SystemExit):
parser.parse_args(
["--namespace", "test-ns", "--environment", "invalid-environment"]
)
def test_parser_all_existing_args_still_work():
"""Test that existing planner arguments still work."""
parser = create_sla_planner_parser()
args = parser.parse_args(
[
"--namespace",
"test-ns",
"--backend",
"vllm",
"--environment",
"kubernetes",
"--ttft",
"200",
"--itl",
"50",
"--max-gpu-budget",
"16",
"--adjustment-interval",
"60",
]
)
assert args.namespace == "test-ns"
assert args.backend == "vllm"
assert args.environment == "kubernetes"
assert args.ttft == 200
assert args.itl == 50
assert args.max_gpu_budget == 16
assert args.adjustment_interval == 60
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for remote planner components.
Tests RemotePlannerClient (low-level) and GlobalPlannerConnector (high-level)
for delegating scale requests to GlobalPlanner.
"""
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.remote_planner_client import RemotePlannerClient
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse, ScaleStatus
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
@pytest.fixture
def mock_runtime():
"""Create a mock DistributedRuntime."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
# Mock scale_request to return a response
client_mock.scale_request = AsyncMock(
return_value={
"status": "success",
"message": "Scaled successfully",
"current_replicas": {"prefill": 3, "decode": 5},
}
)
return runtime, client_mock
@pytest.mark.asyncio
async def test_send_scale_request_success(mock_runtime):
"""Test successful scale request (exercises protocol, client, and serialization)."""
runtime, mock_client = mock_runtime
client = RemotePlannerClient(runtime, "central-ns", "Planner")
request = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=3
),
TargetReplica(
sub_component_type=SubComponentType.DECODE, desired_replicas=5
),
],
blocking=False,
)
response = await client.send_scale_request(request)
assert response.status == ScaleStatus.SUCCESS
assert response.message == "Scaled successfully"
assert response.current_replicas["prefill"] == 3
assert response.current_replicas["decode"] == 5
# Verify lazy init happened
assert client._client is not None
runtime.namespace.assert_called_once_with("central-ns")
@pytest.mark.asyncio
async def test_send_scale_request_error():
"""Test scale request error handling."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
# Mock scale_request to return error response
client_mock.scale_request = AsyncMock(
return_value={
"status": "error",
"message": "Namespace not authorized",
"current_replicas": {},
}
)
client = RemotePlannerClient(runtime, "central-ns", "Planner")
request = ScaleRequest(
caller_namespace="unauthorized-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=1
)
],
)
response = await client.send_scale_request(request)
assert response.status == ScaleStatus.ERROR
assert "not authorized" in response.message
@pytest.mark.asyncio
async def test_send_scale_request_no_response():
"""Test scale request when no response is received."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
# Mock scale_request to return None
client_mock.scale_request = AsyncMock(return_value=None)
client = RemotePlannerClient(runtime, "central-ns", "Planner")
request = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=1
)
],
)
with pytest.raises(RuntimeError, match="No response from centralized planner"):
await client.send_scale_request(request)
@pytest.mark.asyncio
async def test_multiple_requests_reuse_client(mock_runtime):
"""Test that multiple requests reuse the same client instance."""
runtime, mock_client = mock_runtime
client = RemotePlannerClient(runtime, "central-ns", "Planner")
request1 = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=2
)
],
)
request2 = ScaleRequest(
caller_namespace="app-ns",
graph_deployment_name="my-dgd",
k8s_namespace="default",
target_replicas=[
TargetReplica(
sub_component_type=SubComponentType.PREFILL, desired_replicas=4
)
],
)
# Send first request
await client.send_scale_request(request1)
first_client = client._client
# Send second request
await client.send_scale_request(request2)
second_client = client._client
# Should be the same client instance
assert first_client is second_client
# ============================================================================
# GlobalPlannerConnector Tests
# ============================================================================
@pytest.fixture
def connector_runtime():
"""Mock runtime for GlobalPlannerConnector"""
return MagicMock()
@pytest.fixture
def connector(connector_runtime):
"""Create GlobalPlannerConnector instance"""
return GlobalPlannerConnector(
runtime=connector_runtime,
dynamo_namespace="test-ns",
global_planner_namespace="global-ns",
model_name="test-model",
)
@pytest.mark.asyncio
async def test_connector_initialization(connector, connector_runtime):
"""Test GlobalPlannerConnector initialization and async_init"""
assert connector.dynamo_namespace == "test-ns"
assert connector.global_planner_namespace == "global-ns"
assert connector.remote_client is None
with patch(
"dynamo.planner.global_planner_connector.RemotePlannerClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
await connector._async_init()
mock_client_class.assert_called_once_with(
connector_runtime, "global-ns", "GlobalPlanner"
)
assert connector.remote_client == mock_client
@pytest.mark.asyncio
async def test_connector_set_replicas_success(connector):
"""Test GlobalPlannerConnector scaling with enum conversion and predicted load"""
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name="prefill-svc",
desired_replicas=3,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name="decode-svc",
desired_replicas=5,
),
]
with patch.dict(
os.environ, {"DYN_PARENT_DGD_K8S_NAME": "dgd", "POD_NAMESPACE": "ns"}
):
mock_response = ScaleResponse(
status=ScaleStatus.SUCCESS,
message="OK",
current_replicas={"prefill": 3, "decode": 5},
)
mock_client = AsyncMock()
mock_client.send_scale_request = AsyncMock(return_value=mock_response)
connector.remote_client = mock_client
connector.set_predicted_load(100.0, 512.0, 256.0)
await connector.set_component_replicas(target_replicas, blocking=False)
# Verify request structure and enum to string conversion
request = mock_client.send_scale_request.call_args[0][0]
assert request.caller_namespace == "test-ns"
assert request.blocking is False
assert request.predicted_load["num_requests"] == 100.0
assert len(request.target_replicas) == 2
assert request.target_replicas[0].sub_component_type == "prefill"
assert isinstance(request.target_replicas[0].sub_component_type, str)
@pytest.mark.asyncio
async def test_connector_error_handling(connector):
"""Test GlobalPlannerConnector error handling"""
# Empty list
with pytest.raises(EmptyTargetReplicasError):
await connector.set_component_replicas([])
# Uninitialized
target = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name="p",
desired_replicas=1,
)
]
with pytest.raises(RuntimeError, match="not initialized"):
await connector.set_component_replicas(target)
# Error response
with patch.dict(os.environ, {"DYN_PARENT_DGD_K8S_NAME": "d", "POD_NAMESPACE": "n"}):
mock_response = ScaleResponse(
status=ScaleStatus.ERROR, message="Failed", current_replicas={}
)
mock_client = AsyncMock()
mock_client.send_scale_request = AsyncMock(return_value=mock_response)
connector.remote_client = mock_client
with pytest.raises(RuntimeError, match="GlobalPlanner scaling failed"):
await connector.set_component_replicas(target)
@pytest.mark.asyncio
async def test_connector_unsupported_and_noop_operations(connector):
"""Test unsupported and no-op operations"""
# Unsupported
with pytest.raises(NotImplementedError, match="batch operations"):
await connector.add_component(SubComponentType.PREFILL)
with pytest.raises(NotImplementedError, match="batch operations"):
await connector.remove_component(SubComponentType.DECODE)
# No-op operations
await connector.validate_deployment(
prefill_component_name="p", decode_component_name="d"
)
await connector.wait_for_deployment_ready()
def test_connector_model_name_and_predicted_load(connector_runtime):
"""Test GlobalPlannerConnector model name and predicted load tracking"""
# With model name
c1 = GlobalPlannerConnector(connector_runtime, "ns", "gns", "GP", model_name="test")
assert c1.get_model_name() == "test"
# Without model name
c2 = GlobalPlannerConnector(connector_runtime, "ns", "gns", "GP", model_name=None)
assert c2.get_model_name() == "managed-remotely"
# Predicted load
c1.set_predicted_load(42.0, 256.0, 128.0)
assert c1.last_predicted_load == {"num_requests": 42.0, "isl": 256.0, "osl": 128.0}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment