Unverified commit 95c45509, authored by Hongkuan Zhou and committed by GitHub
Browse files

refactor: restructure planner package into subpackage hierarchy (#7689)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 68628976
...@@ -108,7 +108,8 @@ TensorRT-LLM ...@@ -108,7 +108,8 @@ TensorRT-LLM
# The Linux kernel generates core dump files when a process crashes (receives signals # The Linux kernel generates core dump files when a process crashes (receives signals
# like SIGSEGV, SIGABRT, SIGBUS, etc.). The Linux container's /proc/sys/kernel/core_pattern # like SIGSEGV, SIGABRT, SIGBUS, etc.). The Linux container's /proc/sys/kernel/core_pattern
# determines the filename (typically core or core.<pid>) # determines the filename (typically core or core.<pid>)
core /core
core.*
# Ruler Generated Files # Ruler Generated Files
/.cursor/instructions.md /.cursor/instructions.md
......
...@@ -7,8 +7,8 @@ import asyncio ...@@ -7,8 +7,8 @@ import asyncio
import logging import logging
from dynamo.planner import KubernetesConnector from dynamo.planner import KubernetesConnector
from dynamo.planner.kube import KubernetesAPI from dynamo.planner.connectors.kubernetes_api import KubernetesAPI
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse, ScaleStatus from dynamo.planner.connectors.protocol import ScaleRequest, ScaleResponse, ScaleStatus
from dynamo.runtime import DistributedRuntime, dynamo_endpoint from dynamo.runtime import DistributedRuntime, dynamo_endpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -33,7 +33,7 @@ from pathlib import Path ...@@ -33,7 +33,7 @@ from pathlib import Path
import numpy as np import numpy as np
from dynamo.planner.utils.perf_interpolation import ( from dynamo.planner.core.throughput.interpolation import (
DecodeInterpolator, DecodeInterpolator,
PrefillInterpolator, PrefillInterpolator,
) )
......
...@@ -11,13 +11,17 @@ __all__ = [ ...@@ -11,13 +11,17 @@ __all__ = [
"SubComponentType", "SubComponentType",
"WorkerInfo", "WorkerInfo",
] ]
# Import the classes
from dynamo.planner.defaults import SLAPlannerDefaults, SubComponentType from dynamo.planner.config.defaults import (
from dynamo.planner.global_planner_connector import GlobalPlannerConnector SLAPlannerDefaults,
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica SubComponentType,
from dynamo.planner.planner_connector import PlannerConnector TargetReplica,
from dynamo.planner.virtual_connector import VirtualConnector )
from dynamo.planner.worker_info import WorkerInfo from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.connectors.global_planner import GlobalPlannerConnector
from dynamo.planner.connectors.kubernetes import KubernetesConnector
from dynamo.planner.connectors.virtual import VirtualConnector
from dynamo.planner.monitoring.worker_info import WorkerInfo
try: try:
from ._version import __version__ from ._version import __version__
......
...@@ -20,11 +20,11 @@ from typing import Union ...@@ -20,11 +20,11 @@ from typing import Union
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.planner.utils.agg_planner import AggPlanner from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.utils.decode_planner import DecodePlanner from dynamo.planner.core.agg import AggPlanner
from dynamo.planner.utils.disagg_planner import DisaggPlanner from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.core.disagg import DisaggPlanner
from dynamo.planner.utils.prefill_planner import PrefillPlanner from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ComponentName:
    """Base class for backend component name configurations.

    Every attribute defaults to the empty string; backend-specific
    subclasses are expected to override them with concrete values.
    """

    # Prefill worker identifiers
    prefill_worker_k8s_name: str = ""
    prefill_worker_component_name: str = ""
    prefill_worker_endpoint: str = ""

    # Decode worker identifiers
    decode_worker_k8s_name: str = ""
    decode_worker_component_name: str = ""
    decode_worker_endpoint: str = ""
class VllmComponentName(ComponentName):
    """Component names used for the vLLM backend."""

    # Prefill worker identifiers
    prefill_worker_k8s_name = "VllmPrefillWorker"
    prefill_worker_component_name = "prefill"
    prefill_worker_endpoint = "generate"

    # Decode worker identifiers
    decode_worker_k8s_name = "VllmDecodeWorker"
    decode_worker_component_name = "backend"
    decode_worker_endpoint = "generate"
class SGLangComponentName(ComponentName):
    """Component names used for the SGLang backend."""

    # The k8s names below are deliberately short so they stay within k8s
    # limits when used with grove.
    prefill_worker_k8s_name = "prefill"
    prefill_worker_component_name = "prefill"
    prefill_worker_endpoint = "generate"

    decode_worker_k8s_name = "decode"
    decode_worker_component_name = "backend"
    decode_worker_endpoint = "generate"
class TrtllmComponentName(ComponentName):
    """Component names used for the TensorRT-LLM backend.

    Unified frontend architecture (consistent with vLLM/SGLang):
    prefill workers use the "prefill" component and decode workers use the
    "tensorrt_llm" component.
    """

    # Prefill worker identifiers
    prefill_worker_k8s_name = "TRTLLMPrefillWorker"
    prefill_worker_component_name = "prefill"
    prefill_worker_endpoint = "generate"

    # Decode worker identifiers
    decode_worker_k8s_name = "TRTLLMDecodeWorker"
    decode_worker_component_name = "tensorrt_llm"
    decode_worker_endpoint = "generate"
class MockerComponentName(ComponentName):
    """Component names for the mocker backend (testing/simulation purposes)."""

    # Prefill worker identifiers
    prefill_worker_k8s_name = "prefill"
    prefill_worker_component_name = "prefill"
    prefill_worker_endpoint = "generate"

    # Decode worker identifiers
    decode_worker_k8s_name = "decode"
    decode_worker_component_name = "backend"
    decode_worker_endpoint = "generate"
# Lookup table from backend identifier to the class holding that backend's
# component names.
WORKER_COMPONENT_NAMES: dict[str, type[ComponentName]] = {
    "vllm": VllmComponentName,
    "sglang": SGLangComponentName,
    "trtllm": TrtllmComponentName,
    "mocker": MockerComponentName,
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
class BasePlannerDefaults:
    """Default values shared by planner configurations.

    Two values are read from the process environment when the class body is
    executed: ``DYN_NAMESPACE`` and ``PLANNER_PROMETHEUS_PORT``.
    """

    # Namespace from DYN_NAMESPACE env var (injected by operator as "{k8s_namespace}-{dgd_name}")
    namespace = os.getenv("DYN_NAMESPACE", "dynamo")
    environment: Literal["kubernetes", "virtual", "global-planner"] = "kubernetes"
    backend: Literal["vllm", "sglang", "trtllm", "mocker"] = "vllm"
    no_operation = False
    log_dir = None

    throughput_adjustment_interval = 180  # in seconds
    max_gpu_budget = 8
    min_endpoint = 1  # applies to both decode and prefill
    decode_engine_num_gpu = 1
    prefill_engine_num_gpu = 1

    # Port for exposing planner's own metrics (0 means disabled)
    metric_reporting_prometheus_port = int(os.getenv("PLANNER_PROMETHEUS_PORT", "0"))
class SLAPlannerDefaults(BasePlannerDefaults):
    """Default values for the SLA planner, layered on the base planner defaults.

    ``PROMETHEUS_ENDPOINT`` is read from the process environment when the
    class body is executed.
    """

    # Prometheus endpoint URL for pulling/querying metrics
    metric_pulling_prometheus_endpoint = os.getenv(
        "PROMETHEUS_ENDPOINT",
        "http://prometheus-kube-prometheus-prometheus.monitoring.svc.cluster.local:9090",
    )
    profile_results_dir = "profiling_results"

    # Sequence lengths (tokens) and latencies (milliseconds)
    isl = 3000  # in number of tokens
    osl = 150  # in number of tokens
    ttft = 500.0  # in milliseconds
    itl = 50.0  # in milliseconds

    # for load predictor
    load_predictor = "arima"  # ["constant", "arima", "kalman", "prophet"]
    prophet_window_size = 50
    load_predictor_log1p = False
    kalman_q_level = 1.0
    kalman_q_trend = 0.1
    kalman_r = 10.0
    kalman_min_points = 5

    no_correction = True
    mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg"
    throughput_metrics_source: Literal["frontend", "router"] = "frontend"

    # Scaling mode flags
    enable_throughput_scaling = True
    enable_load_scaling = False

    # Load-based scaling settings
    load_adjustment_interval = 5  # in seconds, must be < throughput_adjustment_interval
    load_learning_window = 50  # sliding window size for regression
    load_scaling_down_sensitivity = 80  # 0-100
    load_metric_samples = 10  # number of samples per interval
    load_min_observations = 5  # cold start threshold
class SubComponentType(str, Enum):
    """String-valued enumeration of the planner's sub-component kinds."""

    PREFILL = "prefill"
    DECODE = "decode"
class TargetReplica(BaseModel):
    """Pydantic model pairing a sub-component with a desired replica count."""

    # Which sub-component (prefill or decode) this target refers to.
    sub_component_type: SubComponentType
    # Optional explicit component name; defaults to None when not given.
    component_name: Optional[str] = None
    # Replica count requested for the sub-component.
    desired_replicas: int
...@@ -23,7 +23,7 @@ from typing import Literal, Optional ...@@ -23,7 +23,7 @@ from typing import Literal, Optional
import yaml import yaml
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from dynamo.planner.defaults import SLAPlannerDefaults from dynamo.planner.config.defaults import SLAPlannerDefaults
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dynamo.planner.defaults import SubComponentType from dynamo.planner.config.defaults import SubComponentType
# TODO: add ability to scale component to X replicas # TODO: add ability to scale component to X replicas
......
...@@ -8,12 +8,11 @@ import os ...@@ -8,12 +8,11 @@ import os
import time import time
from typing import Optional from typing import Optional
from dynamo.planner.defaults import SubComponentType from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.kubernetes_connector import TargetReplica from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.connectors.protocol import ScaleRequest, ScaleStatus
from dynamo.planner.remote_planner_client import RemotePlannerClient from dynamo.planner.connectors.remote_client import RemotePlannerClient
from dynamo.planner.scale_protocol import ScaleRequest, ScaleStatus from dynamo.planner.errors import EmptyTargetReplicasError
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
......
...@@ -18,15 +18,10 @@ import logging ...@@ -18,15 +18,10 @@ import logging
import os import os
from typing import Optional from typing import Optional
from pydantic import BaseModel from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.defaults import ( from dynamo.planner.connectors.kubernetes_api import KubernetesAPI
SubComponentType, from dynamo.planner.errors import (
get_service_from_sub_component_type_or_name,
)
from dynamo.planner.kube import KubernetesAPI
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.utils.exceptions import (
DeploymentModelNameMismatchError, DeploymentModelNameMismatchError,
DeploymentValidationError, DeploymentValidationError,
EmptyTargetReplicasError, EmptyTargetReplicasError,
...@@ -34,19 +29,19 @@ from dynamo.planner.utils.exceptions import ( ...@@ -34,19 +29,19 @@ from dynamo.planner.utils.exceptions import (
PlannerError, PlannerError,
UserProvidedModelNameMismatchError, UserProvidedModelNameMismatchError,
) )
from dynamo.planner.worker_info import WorkerInfo, build_worker_info_from_defaults from dynamo.planner.monitoring.dgd_services import (
get_service_from_sub_component_type_or_name,
)
from dynamo.planner.monitoring.worker_info import (
WorkerInfo,
build_worker_info_from_defaults,
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TargetReplica(BaseModel):
    """Pydantic model pairing a sub-component with a desired replica count."""

    # Which sub-component (prefill or decode) this target refers to.
    sub_component_type: SubComponentType
    # Optional explicit component name; defaults to None when not given.
    component_name: Optional[str] = None
    # Replica count requested for the sub-component.
    desired_replicas: int
class KubernetesConnector(PlannerConnector): class KubernetesConnector(PlannerConnector):
def __init__( def __init__(
self, self,
......
...@@ -20,7 +20,7 @@ from typing import Optional ...@@ -20,7 +20,7 @@ from typing import Optional
from kubernetes import client, config from kubernetes import client, config
from kubernetes.config.config_exception import ConfigException from kubernetes.config.config_exception import ConfigException
from dynamo.planner.utils.exceptions import DynamoGraphDeploymentNotFoundError from dynamo.planner.errors import DynamoGraphDeploymentNotFoundError
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
......
...@@ -8,7 +8,7 @@ from typing import List, Optional ...@@ -8,7 +8,7 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.planner.kubernetes_connector import TargetReplica from dynamo.planner.config.defaults import TargetReplica
class ScaleStatus(str, Enum): class ScaleStatus(str, Enum):
......
...@@ -7,8 +7,8 @@ import asyncio ...@@ -7,8 +7,8 @@ import asyncio
import logging import logging
from dynamo._core import Client from dynamo._core import Client
from dynamo.planner.defaults import SubComponentType from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse from dynamo.planner.connectors.protocol import ScaleRequest, ScaleResponse
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -6,9 +6,9 @@ import os ...@@ -6,9 +6,9 @@ import os
from typing import Optional from typing import Optional
from dynamo._core import VirtualConnectorCoordinator from dynamo._core import VirtualConnectorCoordinator
from dynamo.planner import SubComponentType, TargetReplica from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError from dynamo.planner.errors import EmptyTargetReplicasError
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
......
...@@ -5,19 +5,20 @@ import asyncio ...@@ -5,19 +5,20 @@ import asyncio
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from dynamo.planner import SubComponentType, TargetReplica from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.config.planner_config import PlannerConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.utils.planner_core import (
BasePlanner, from dynamo.planner.core.base import BasePlanner
PlannerPrometheusMetrics, from dynamo.planner.core.budget import (
PlannerSharedState,
_apply_component_gpu_budget, _apply_component_gpu_budget,
_initialize_gpu_counts, _initialize_gpu_counts,
) )
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -67,7 +68,7 @@ class AggPlanner: ...@@ -67,7 +68,7 @@ class AggPlanner:
component_type=SubComponentType.DECODE, component_type=SubComponentType.DECODE,
) )
from dynamo.planner.utils.fpm_regression import AggRegressionModel from dynamo.planner.core.load.fpm_regression import AggRegressionModel
self.regression = AggRegressionModel( self.regression = AggRegressionModel(
window_size=config.load_learning_window, window_size=config.load_learning_window,
......
...@@ -3,249 +3,45 @@ ...@@ -3,249 +3,45 @@
import asyncio import asyncio
import logging import logging
import math
import time import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from prometheus_client import Gauge, start_http_server from prometheus_client import start_http_server
from dynamo.planner import ( from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
KubernetesConnector, from dynamo.planner.config.defaults import SubComponentType, TargetReplica
SubComponentType, from dynamo.planner.config.planner_config import PlannerConfig
TargetReplica, from dynamo.planner.connectors.global_planner import GlobalPlannerConnector
VirtualConnector, from dynamo.planner.connectors.kubernetes import KubernetesConnector
from dynamo.planner.connectors.virtual import VirtualConnector
from dynamo.planner.core.budget import (
_apply_component_gpu_budget,
_initialize_gpu_counts,
)
from dynamo.planner.core.load.predictors import LOAD_PREDICTORS
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.core.throughput.interpolation import (
DecodeInterpolator,
PrefillInterpolator,
) )
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.core.throughput.pre_swept_results import PreSweptResultsHelper
from dynamo.planner.global_planner_connector import GlobalPlannerConnector from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.planner.utils.exceptions import DeploymentValidationError from dynamo.planner.monitoring.traffic_metrics import Metrics, PrometheusAPIClient
from dynamo.planner.monitoring.worker_info import WorkerInfo, resolve_worker_info
from dynamo.planner.offline.trace_data import extract_metrics_from_mooncake
if TYPE_CHECKING: if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.llm import FpmEventSubscriber from dynamo.llm import FpmEventSubscriber
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import (
DecodeInterpolator,
PrefillInterpolator,
)
from dynamo.planner.utils.planner_config import PlannerConfig
from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper
from dynamo.planner.utils.prometheus import Metrics, PrometheusAPIClient
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
from dynamo.planner.worker_info import WorkerInfo, resolve_worker_info
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
# Union of all connector types used by the planner
ConnectorType = Union[GlobalPlannerConnector, KubernetesConnector, VirtualConnector] ConnectorType = Union[GlobalPlannerConnector, KubernetesConnector, VirtualConnector]
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PlannerPrometheusMetrics:
    """Container for all Planner Prometheus metrics."""

    def __init__(self, prefix: str = "planner"):
        """Create one Gauge per planner metric, each named ``{prefix}:<metric>``.

        Args:
            prefix: String placed before the colon in every metric name.
        """

        def gauge(metric: str, description: str) -> Gauge:
            # Single place where the "<prefix>:<metric>" naming scheme is applied.
            return Gauge(f"{prefix}:{metric}", description)

        # Worker counts
        self.num_p_workers = gauge("num_p_workers", "Number of prefill workers")
        self.num_d_workers = gauge("num_d_workers", "Number of decode workers")

        # Observed metrics
        self.observed_ttft = gauge("observed_ttft", "Observed time to first token (ms)")
        self.observed_itl = gauge("observed_itl", "Observed inter-token latency (ms)")
        self.observed_request_rate = gauge(
            "observed_request_rate", "Observed request rate (req/s)"
        )
        self.observed_request_duration = gauge(
            "observed_request_duration", "Observed request duration (s)"
        )
        self.observed_isl = gauge("observed_isl", "Observed input sequence length")
        self.observed_osl = gauge("observed_osl", "Observed output sequence length")

        # Correction factors
        self.p_correction_factor = gauge(
            "p_correction_factor", "Prefill correction factor"
        )
        self.d_correction_factor = gauge(
            "d_correction_factor", "Decode correction factor"
        )

        # Predicted metrics
        self.predicted_request_rate = gauge(
            "predicted_request_rate", "Predicted request rate (req/s)"
        )
        self.predicted_isl = gauge("predicted_isl", "Predicted input sequence length")
        self.predicted_osl = gauge("predicted_osl", "Predicted output sequence length")
        self.predicted_num_p = gauge(
            "predicted_num_p", "Predicted number of prefill replicas"
        )
        self.predicted_num_d = gauge(
            "predicted_num_d", "Predicted number of decode replicas"
        )

        # Cumulative GPU usage
        self.gpu_hours = gauge("gpu_hours", "Cumulative GPU hours used")
@dataclass
class PlannerSharedState:
    """Planner state record; every field has a default value."""

    # Most recent metrics snapshot; a fresh Metrics() is built per instance.
    last_metrics: Metrics = field(default_factory=Metrics)

    # Worker counts for prefill (p) and decode (d).
    num_p_workers: int = 0
    num_d_workers: int = 0

    cumulative_gpu_hours: float = 0.0
    last_adjustment_time: float = 0.0

    # Lower bounds from throughput-based scaling (used when both modes enabled)
    throughput_lower_bound_p: int = 1
    throughput_lower_bound_d: int = 1

    # Separate timestamp for load-based adjustment loop
    last_load_adjustment_time: float = 0.0
def _apply_global_gpu_budget(
    next_num_p: int, next_num_d: int, config: PlannerConfig
) -> tuple[int, int]:
    """Constrain prefill and decode replica counts to the configured GPU budget.

    A negative ``config.max_gpu_budget`` leaves the inputs untouched, as does a
    request whose GPU total (num_p * prefill_gpus + num_d * decode_gpus) already
    fits. Otherwise prefill is scaled by ``budget / total_required``, floored,
    and clamped to ``[min_endpoint, max_prefill]``, where ``max_prefill`` keeps
    enough GPUs free for ``min_endpoint`` decode replicas. Whatever budget is
    left after prefill is handed to decode (never fewer than ``min_endpoint``).

    Args:
        next_num_p: Requested number of prefill replicas.
        next_num_d: Requested number of decode replicas.
        config: Planner configuration supplying ``max_gpu_budget``,
            ``min_endpoint`` and the per-engine GPU counts.

    Returns:
        The ``(prefill, decode)`` replica counts after applying the budget, or
        ``(0, 0)`` when the budget cannot cover ``min_endpoint`` replicas of both.
    """
    if config.max_gpu_budget < 0:
        # A negative budget means the inputs are returned unchanged.
        return next_num_p, next_num_d

    assert config.prefill_engine_num_gpu is not None
    assert config.decode_engine_num_gpu is not None
    gpus_per_prefill = config.prefill_engine_num_gpu
    gpus_per_decode = config.decode_engine_num_gpu

    total_gpu_required = next_num_p * gpus_per_prefill + next_num_d * gpus_per_decode
    if total_gpu_required <= config.max_gpu_budget:
        return next_num_p, next_num_d

    min_required = (
        config.min_endpoint * gpus_per_prefill + config.min_endpoint * gpus_per_decode
    )
    if config.max_gpu_budget < min_required:
        logger.warning(
            f"max_gpu_budget ({config.max_gpu_budget}) is below the minimum required "
            f"for min_endpoint ({min_required}); enforcing zero replicas"
        )
        return 0, 0

    scale = config.max_gpu_budget / total_gpu_required

    # Largest prefill count that still leaves room for min_endpoint decode replicas.
    reserved_for_decode = config.min_endpoint * gpus_per_decode
    max_prefill = math.floor(
        (config.max_gpu_budget - reserved_for_decode) / gpus_per_prefill
    )

    # Scale prefill, cap it at max_prefill, then raise it to min_endpoint if needed.
    scaled_prefill = math.floor(next_num_p * scale)
    if scaled_prefill > max_prefill:
        scaled_prefill = max_prefill
    if scaled_prefill < config.min_endpoint:
        scaled_prefill = config.min_endpoint
    next_num_p = scaled_prefill

    # Decode receives whatever budget prefill did not consume.
    remaining = config.max_gpu_budget - next_num_p * gpus_per_prefill
    decode_fit = math.floor(remaining / gpus_per_decode)
    next_num_d = decode_fit if decode_fit > config.min_endpoint else config.min_endpoint

    logger.warning(
        f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({config.max_gpu_budget}), "
        f"scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
    )
    return next_num_p, next_num_d
def _apply_component_gpu_budget(
    desired_replicas: int, engine_num_gpu: int, config: PlannerConfig
) -> int:
    """Constrain a single component's replica count to the configured GPU budget.

    Used for one component at a time (prefill-only or decode-only). A negative
    ``config.max_gpu_budget`` leaves the request untouched, as does a request
    whose GPU total (replicas * gpus_per_replica) already fits. Otherwise the
    count is scaled by ``budget / total_required``, floored, and raised to at
    least ``config.min_endpoint``.

    Args:
        desired_replicas: Requested number of replicas.
        engine_num_gpu: GPUs consumed by each replica.
        config: Planner configuration supplying ``max_gpu_budget`` and
            ``min_endpoint``.

    Returns:
        The replica count after applying the budget, or ``0`` when the budget
        cannot cover ``min_endpoint`` replicas.
    """
    if config.max_gpu_budget < 0:
        # A negative budget means the request is returned unchanged.
        return desired_replicas

    total_gpu_required = desired_replicas * engine_num_gpu
    if total_gpu_required <= config.max_gpu_budget:
        return desired_replicas

    min_required = config.min_endpoint * engine_num_gpu
    if config.max_gpu_budget < min_required:
        logger.warning(
            f"max_gpu_budget ({config.max_gpu_budget}) is below the minimum required "
            f"for min_endpoint ({min_required}); enforcing zero replicas"
        )
        return 0

    scale = config.max_gpu_budget / total_gpu_required
    scaled_replicas = math.floor(desired_replicas * scale)
    if scaled_replicas > config.min_endpoint:
        next_num = scaled_replicas
    else:
        next_num = config.min_endpoint

    logger.warning(
        f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({config.max_gpu_budget}), "
        f"scaling down to {next_num} replicas"
    )
    return next_num
def _initialize_gpu_counts(
    config: PlannerConfig,
    connector,
    require_prefill: bool,
    require_decode: bool,
) -> None:
    """Populate the per-engine GPU counts on ``config``.

    When ``connector`` exposes ``get_gpu_counts`` (Kubernetes mode, reading
    from the DGD), its result is written to ``config``. If that call fails for
    any reason, a warning is logged and the values already present on
    ``config`` (CLI flags) are used instead - useful for mockers that don't
    specify GPU resources. Connectors without ``get_gpu_counts`` (virtual
    mode) rely on the CLI flags directly.

    Args:
        config: Planner configuration; ``prefill_engine_num_gpu`` and
            ``decode_engine_num_gpu`` are read and possibly overwritten.
        connector: Connector object, optionally providing ``get_gpu_counts``.
        require_prefill: Whether a prefill GPU count must be available.
        require_decode: Whether a decode GPU count must be available.

    Raises:
        DeploymentValidationError: If a required GPU count cannot be determined.
    """
    # Try to read from DGD in Kubernetes mode
    if hasattr(connector, "get_gpu_counts"):
        requirements = {
            "require_prefill": require_prefill,
            "require_decode": require_decode,
        }
        try:
            prefill_gpu, decode_gpu = connector.get_gpu_counts(**requirements)
            config.prefill_engine_num_gpu = prefill_gpu
            config.decode_engine_num_gpu = decode_gpu
            logger.info(
                f"Detected GPU counts from DGD: prefill={prefill_gpu}, decode={decode_gpu}"
            )
            return
        except Exception as e:
            # Fall back to CLI flags (e.g., for mockers without GPU resources in DGD)
            logger.warning(
                f"Could not read GPU counts from DGD ({e}), falling back to CLI flags"
            )

    # Use CLI flags (virtual mode, or K8s fallback when DGD lacks GPU resources)
    missing = []
    if require_prefill and config.prefill_engine_num_gpu is None:
        missing += ["Missing prefill_engine_num_gpu in config"]
    if require_decode and config.decode_engine_num_gpu is None:
        missing += ["Missing decode_engine_num_gpu in config"]
    if missing:
        raise DeploymentValidationError(missing)

    logger.info(
        f"Using GPU counts from CLI: prefill={config.prefill_engine_num_gpu}, "
        f"decode={config.decode_engine_num_gpu}"
    )
class BasePlanner: class BasePlanner:
component_type: SubComponentType component_type: SubComponentType
...@@ -429,7 +225,7 @@ class BasePlanner: ...@@ -429,7 +225,7 @@ class BasePlanner:
self.no_correction = config.no_correction self.no_correction = config.no_correction
if self.enable_load: if self.enable_load:
from dynamo.planner.utils.fpm_regression import ( from dynamo.planner.core.load.fpm_regression import (
DecodeRegressionModel, DecodeRegressionModel,
PrefillRegressionModel, PrefillRegressionModel,
) )
...@@ -755,7 +551,6 @@ class BasePlanner: ...@@ -755,7 +551,6 @@ class BasePlanner:
def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]: def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try: try:
# predict the next load
next_num_req = self.num_req_predictor.predict_next() next_num_req = self.num_req_predictor.predict_next()
next_isl = self.isl_predictor.predict_next() next_isl = self.isl_predictor.predict_next()
next_osl = self.osl_predictor.predict_next() next_osl = self.osl_predictor.predict_next()
...@@ -775,7 +570,6 @@ class BasePlanner: ...@@ -775,7 +570,6 @@ class BasePlanner:
self.osl_predictor.add_data_point(osl_avg) self.osl_predictor.add_data_point(osl_avg)
def plan_adjustment(self) -> Optional[int]: def plan_adjustment(self) -> Optional[int]:
# Skip adjustment if no traffic
if not self.last_metrics.is_valid(): if not self.last_metrics.is_valid():
logger.info( logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment" "Metrics contain None or NaN values (no active requests), skipping adjustment"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment