Unverified Commit f7e0b3fd authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor(planner): extract discrete-event state machine with explicit inputs (#8046)

parent 39a6a240
......@@ -21,10 +21,12 @@ from typing import Union
from pydantic import BaseModel
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.agg import AggPlanner
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.disagg import DisaggPlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.adapters import (
AggPlanner,
DecodePlanner,
DisaggPlanner,
PrefillPlanner,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
EngineCapabilities,
FpmObservations,
PlannerEffects,
ScalingDecision,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
__all__ = [
"EngineCapabilities",
"FpmObservations",
"PlannerEffects",
"PlannerStateMachine",
"ScalingDecision",
"ScheduledTick",
"TickInput",
"TrafficObservation",
"WorkerCapabilities",
"WorkerCounts",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Native planner adapter subclasses (one per mode).
Each subclass sets ``require_prefill`` / ``require_decode`` and overrides
``_bootstrap_regression()`` and ``_apply_effects()``. Everything else
(connector, Prometheus, FPM subscribers, tick loop) is in ``NativePlannerBase``.
"""
import logging
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.core.base import NativePlannerBase
from dynamo.planner.core.types import PlannerEffects
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
logger = logging.getLogger(__name__)
class PrefillPlanner(NativePlannerBase):
"""Prefill-only mode."""
require_prefill = True
require_decode = False
async def _bootstrap_regression(self) -> None:
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=self.prefill_worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=SubComponentType.PREFILL,
)
self.state_machine.load_benchmark_fpms(prefill_fpms=fpms)
except Exception as e:
if self.config.enable_throughput_scaling:
raise
logger.warning(f"No pre-deployment data for prefill: {e}")
async def _apply_effects(self, effects: PlannerEffects) -> None:
if effects.scale_to is None or effects.scale_to.num_prefill is None:
return
desired = effects.scale_to.num_prefill
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_p.set(desired)
await self._apply_scaling_targets(
[
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_worker_info.k8s_name,
desired_replicas=desired,
)
]
)
class DecodePlanner(NativePlannerBase):
"""Decode-only mode."""
require_prefill = False
require_decode = True
async def _bootstrap_regression(self) -> None:
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=self.decode_worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=SubComponentType.DECODE,
)
self.state_machine.load_benchmark_fpms(decode_fpms=fpms)
except Exception as e:
if self.config.enable_throughput_scaling:
raise
logger.warning(f"No pre-deployment data for decode: {e}")
async def _apply_effects(self, effects: PlannerEffects) -> None:
if effects.scale_to is None or effects.scale_to.num_decode is None:
return
desired = effects.scale_to.num_decode
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_d.set(desired)
await self._apply_scaling_targets(
[
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
)
class AggPlanner(NativePlannerBase):
"""Aggregated mode (single engine type handles both prefill and decode)."""
require_prefill = False
require_decode = True
async def _bootstrap_regression(self) -> None:
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=self.decode_worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=SubComponentType.DECODE,
)
self.state_machine.load_benchmark_fpms(agg_fpms=fpms)
except Exception as e:
if self.config.enable_throughput_scaling:
raise
logger.warning(f"No pre-deployment data for agg: {e}")
async def _apply_effects(self, effects: PlannerEffects) -> None:
if effects.scale_to is None or effects.scale_to.num_decode is None:
return
desired = effects.scale_to.num_decode
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_d.set(desired)
await self._apply_scaling_targets(
[
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
)
class DisaggPlanner(NativePlannerBase):
"""Disaggregated mode (separate prefill and decode engines)."""
require_prefill = True
require_decode = True
async def _bootstrap_regression(self) -> None:
for component, kwarg in [
(SubComponentType.PREFILL, "prefill_fpms"),
(SubComponentType.DECODE, "decode_fpms"),
]:
worker_info = (
self.prefill_worker_info
if component == SubComponentType.PREFILL
else self.decode_worker_info
)
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=component,
)
self.state_machine.load_benchmark_fpms(**{kwarg: fpms})
except Exception as e:
if self.config.enable_throughput_scaling:
raise
logger.warning(f"No pre-deployment data for {component.value}: {e}")
async def _apply_effects(self, effects: PlannerEffects) -> None:
if effects.scale_to is None:
return
decision = effects.scale_to
if decision.num_prefill is not None and self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_p.set(decision.num_prefill)
if decision.num_decode is not None and self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_d.set(decision.num_decode)
targets = []
if decision.num_prefill is not None:
targets.append(
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_worker_info.k8s_name,
desired_replicas=decision.num_prefill,
)
)
if decision.num_decode is not None:
targets.append(
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_worker_info.k8s_name,
desired_replicas=decision.num_decode,
)
)
await self._apply_scaling_targets(targets)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import math
import time
from typing import TYPE_CHECKING, Optional
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.base import BasePlanner
from dynamo.planner.core.budget import (
_apply_component_gpu_budget,
_initialize_gpu_counts,
)
from dynamo.planner.core.perf_model import AggRegressionModel
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class AggPlanner:
"""Aggregated planner: FPM-driven scaling for single engine type.
In aggregated mode, engines handle both prefill and decode (chunked prefill).
A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens)
to wall_time using 2D linear regression.
Supports load-only, throughput-only, or both scaling modes.
Scaling logic (load-based):
- Estimate next TTFT per engine by simulating prefill chunking with
piggybacked decode (steady-state decode load).
- Estimate next ITL per engine by predicting decode iteration time with
average piggybacked prefill load.
- Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA).
- Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity).
Scaling logic (throughput-based):
- Use compute_agg_replicas() to find minimum replicas where both SLAs
are met under predicted traffic load.
"""
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config
self.runtime = runtime
self.shared_state = PlannerSharedState()
self.enable_throughput = config.enable_throughput_scaling
self.enable_load = config.enable_load_scaling
if not self.enable_throughput and not self.enable_load:
raise ValueError(
"Aggregated planner requires at least one scaling mode enabled."
)
prometheus_metrics = PlannerPrometheusMetrics()
self.planner = BasePlanner(
runtime,
config,
shared_state=self.shared_state,
prometheus_metrics=prometheus_metrics,
start_prometheus_server=True,
component_type=SubComponentType.DECODE,
)
self.regression = AggRegressionModel(
max_num_fpm_samples=config.max_num_fpm_samples,
min_observations=config.load_min_observations,
bucket_count=config.fpm_sample_bucket_size,
)
async def _async_init(self):
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
if not self.config.no_operation:
connector = getattr(self.planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...")
await self.planner.connector.validate_deployment(
prefill_component_name=None,
decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=False,
require_decode=True,
)
logger.info("Successfully validated the deployment")
_initialize_gpu_counts(
self.config,
self.planner.connector,
require_prefill=False,
require_decode=True,
)
await self.planner.connector.wait_for_deployment_ready(
include_planner=False
)
await self.planner._init_worker_info(require_prefill=False, require_decode=True)
if self.runtime is not None:
await self.planner._init_fpm_subscriber()
await self._bootstrap_regression()
async def _bootstrap_regression(self) -> None:
"""Bootstrap agg regression from pre-deployment benchmark data."""
worker_info = self.planner.decode_worker_info
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.config.namespace,
worker_info=worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=SubComponentType.DECODE,
)
self.regression.load_benchmark_fpms(fpms)
logger.info(
f"Bootstrapped agg regression with {len(fpms)} pre-deployment FPMs"
)
except Exception as e:
if self.enable_throughput:
raise
logger.warning(
f"No pre-deployment data for agg regression: {e}. "
"Load-based scaling will learn from live FPM only."
)
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
self.shared_state.last_adjustment_time = time.time()
loops = []
if self.enable_throughput:
loops.append(self._throughput_loop())
loops.append(self._load_and_fpm_update_loop())
await asyncio.gather(*loops)
async def _throughput_loop(self) -> None:
"""Throughput-based scaling loop for agg mode."""
while True:
current_time = time.time()
if (
current_time - self.shared_state.last_adjustment_time
>= self.config.throughput_adjustment_interval
):
self.shared_state.last_adjustment_time = time.time()
logger.info("New agg throughput adjustment interval started!")
await self.planner.observe_traffic_stats(
require_prefill=False, require_decode=True
)
metrics = self.shared_state.last_metrics
if not metrics.is_valid():
logger.info("Metrics invalid, skipping agg throughput adjustment")
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
next_num_req = self.planner.num_req_predictor.predict_next()
next_isl = self.planner.isl_predictor.predict_next()
next_osl = self.planner.osl_predictor.predict_next()
max_num_batched_tokens = getattr(
self.planner.decode_worker_info, "max_num_batched_tokens", None
)
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available, skipping agg throughput"
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
(
engine_rps,
actual_ttft,
actual_itl,
) = self.regression.find_best_engine_agg_rps(
isl=next_isl,
osl=next_osl,
max_num_batched_tokens=max_num_batched_tokens,
ttft_sla=self.config.ttft,
itl_sla=self.config.itl,
)
if engine_rps <= 0:
logger.warning(
"Agg perf model not ready, skipping throughput scaling"
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
if actual_ttft > self.config.ttft or actual_itl > self.config.itl:
logger.warning(
f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms "
f"(target {self.config.ttft:.1f}ms), "
f"ITL={actual_itl:.1f}ms (target {self.config.itl:.1f}ms), "
"scaling with best achievable rate"
)
demand_rps = next_num_req / self.config.throughput_adjustment_interval
desired = math.ceil(demand_rps / engine_rps)
desired = max(desired, self.config.min_endpoint)
logger.info(
f"Agg: {demand_rps:.2f}(demand rps) / "
f"{engine_rps:.2f}(engine rps) = {desired}(replicas), "
f"est_ttft={actual_ttft:.1f}ms, est_itl={actual_itl:.1f}ms"
)
if self.enable_load:
self.shared_state.throughput_lower_bound_d = desired
logger.info(f"Agg throughput lower bound set to {desired}")
else:
assert self.config.decode_engine_num_gpu is not None
desired = _apply_component_gpu_budget(
desired, self.config.decode_engine_num_gpu, self.config
)
if (
self.planner.prometheus_port != 0
and self.planner.prometheus_metrics is not None
):
self.planner.prometheus_metrics.predicted_num_d.set(desired)
if not self.config.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.planner.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
await self.planner.connector.set_component_replicas(
target_replicas, blocking=False
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_and_fpm_update_loop(self) -> None:
"""FPM observation and (optionally) load-based scaling for agg mode.
Always updates regression with live FPM. When load-based scaling
is enabled, makes scaling decisions immediately after.
"""
pending_desired: Optional[int] = None
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New agg load/FPM update interval started!")
_, num_d, _ = await self.planner.get_workers_info(
require_prefill=False, require_decode=True
)
self.shared_state.num_d_workers = num_d
num_workers = num_d
fpm_stats = self.planner._get_fpm_stats()
if not fpm_stats:
continue
for (wid, dp), fpm in fpm_stats.items():
BasePlanner._log_fpm(wid, dp, fpm, "agg")
self.regression.add_observation(fpm)
if not self.enable_load:
continue
if pending_desired is not None:
if num_workers == pending_desired:
logger.info(
f"Scaling to {pending_desired} complete, resuming decisions"
)
pending_desired = None
else:
logger.info(
f"Scaling in progress ({num_workers} -> {pending_desired}), "
"observing only"
)
continue
if not BasePlanner._reconcile_fpm_worker_count(
fpm_stats, num_workers, "agg"
):
continue
if not self.regression.has_sufficient_data():
logger.info(
f"Agg regression: insufficient data "
f"({self.regression.num_observations}/{self.regression.min_observations})"
)
continue
max_num_batched_tokens = getattr(
self.planner.decode_worker_info, "max_num_batched_tokens", None
)
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available from WorkerInfo, "
"skipping agg scaling"
)
continue
p_desired = self._prefill_scaling_decision(
fpm_stats, num_workers, max_num_batched_tokens
)
d_desired = self._decode_scaling_decision(fpm_stats, num_workers)
logger.info(
f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} "
f"(current={num_workers})"
)
if p_desired is not None and p_desired > num_workers:
desired = p_desired
elif d_desired is not None and d_desired > num_workers:
desired = d_desired
elif (
p_desired is not None
and p_desired < num_workers
and d_desired is not None
and d_desired < num_workers
):
desired = max(p_desired, d_desired)
else:
logger.info("Agg scaling: no scaling needed")
continue
desired = max(desired, self.config.min_endpoint)
if self.enable_throughput:
desired = max(desired, self.shared_state.throughput_lower_bound_d)
assert self.config.decode_engine_num_gpu is not None
desired = _apply_component_gpu_budget(
desired, self.config.decode_engine_num_gpu, self.config
)
logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
if (
self.planner.prometheus_port != 0
and self.planner.prometheus_metrics is not None
):
self.planner.prometheus_metrics.predicted_num_d.set(desired)
if not self.config.no_operation:
pending_desired = desired
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.planner.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
await self.planner.connector.set_component_replicas(
target_replicas, blocking=False
)
def _prefill_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
max_num_batched_tokens: int,
) -> Optional[int]:
estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_ttfts.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_ttfts, self.config.ttft, num_workers, "agg TTFT"
)
def _decode_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
) -> Optional[int]:
estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_itl(
scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_itls.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_itls, self.config.itl, num_workers, "agg ITL"
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Runtime I/O plumbing for the native planner.
This module contains **zero decision logic**. It only gathers data from the
outside world (Prometheus, FPM subscribers, K8s connectors) and applies
scaling decisions back. All scaling logic lives in
:class:`~dynamo.planner.core.state_machine.PlannerStateMachine`.
Subclasses (PrefillPlanner, DecodePlanner, AggPlanner, DisaggPlanner) set
mode-specific flags and override ``_bootstrap_regression`` and
``_apply_effects``.
"""
from __future__ import annotations
import asyncio
import logging
import time
......@@ -9,19 +23,23 @@ from typing import TYPE_CHECKING, Optional, Union
from prometheus_client import start_http_server
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.defaults import TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.connectors.global_planner import GlobalPlannerConnector
from dynamo.planner.connectors.kubernetes import KubernetesConnector
from dynamo.planner.connectors.virtual import VirtualConnector
from dynamo.planner.core.budget import (
_apply_component_gpu_budget,
_initialize_gpu_counts,
from dynamo.planner.core.budget import _initialize_gpu_counts
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
EngineCapabilities,
FpmObservations,
PlannerEffects,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
from dynamo.planner.core.load.predictors import LOAD_PREDICTORS
from dynamo.planner.core.perf_model import DecodeRegressionModel, PrefillRegressionModel
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.planner.monitoring.traffic_metrics import Metrics, PrometheusAPIClient
from dynamo.planner.monitoring.worker_info import WorkerInfo, resolve_worker_info
......@@ -31,7 +49,6 @@ if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.llm import FpmEventSubscriber
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -41,34 +58,63 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__)
class BasePlanner:
component_type: SubComponentType
# ------------------------------------------------------------------
# Helpers for building WorkerCapabilities from resolved WorkerInfo
# ------------------------------------------------------------------
def __init__(
self,
runtime: Optional[DistributedRuntime],
config: PlannerConfig,
shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_traffic_client: Optional[PrometheusAPIClient] = None,
connector: Optional[ConnectorType] = None,
start_prometheus_server: bool = True,
component_type: Optional[SubComponentType] = None,
):
if component_type is not None:
self.component_type = component_type
self.config = config
self.shared_state = shared_state or PlannerSharedState()
def _engine_caps(
worker_info: Optional[WorkerInfo], num_gpu: Optional[int]
) -> Optional[EngineCapabilities]:
if worker_info is None and num_gpu is None:
return None
return EngineCapabilities(
num_gpu=num_gpu,
max_num_batched_tokens=worker_info.max_num_batched_tokens
if worker_info
else None,
max_num_seqs=worker_info.max_num_seqs if worker_info else None,
context_length=worker_info.context_length if worker_info else None,
)
def build_worker_capabilities(
config: PlannerConfig,
prefill_worker_info: Optional[WorkerInfo] = None,
decode_worker_info: Optional[WorkerInfo] = None,
) -> WorkerCapabilities:
return WorkerCapabilities(
prefill=_engine_caps(prefill_worker_info, config.prefill_engine_num_gpu),
decode=_engine_caps(decode_worker_info, config.decode_engine_num_gpu),
)
# ------------------------------------------------------------------
# Base adapter
# ------------------------------------------------------------------
class NativePlannerBase:
"""Base adapter: runtime I/O plumbing shared by all planner modes.
Subclasses set ``require_prefill`` / ``require_decode`` and override
``_bootstrap_regression()`` and ``_apply_effects()``.
"""
require_prefill: bool = False
require_decode: bool = False
def __init__(
self, runtime: Optional[DistributedRuntime], config: PlannerConfig
) -> None:
self.config = config
self.runtime = runtime
self.namespace = config.namespace
self.model_name: Optional[str] = None
self.connector: ConnectorType
if connector is not None:
self.connector = connector
elif not config.no_operation:
# Connector
self.connector: ConnectorType
if not config.no_operation:
if config.environment == "global-planner":
assert config.global_planner_namespace is not None
assert runtime is not None
......@@ -84,217 +130,161 @@ class BasePlanner:
elif config.environment == "virtual":
assert runtime is not None
self.connector = VirtualConnector(
runtime,
self.namespace,
config.model_name,
runtime, self.namespace, config.model_name
)
else:
raise ValueError(f"Invalid environment: {config.environment}")
self.prometheus_traffic_client = (
prometheus_traffic_client
or PrometheusAPIClient(
config.metric_pulling_prometheus_endpoint,
config.namespace,
metrics_source=config.throughput_metrics_source,
)
# Prometheus
self.prometheus_traffic_client = PrometheusAPIClient(
config.metric_pulling_prometheus_endpoint,
config.namespace,
metrics_source=config.throughput_metrics_source,
)
if config.throughput_metrics_source == "router":
self.prometheus_traffic_client.warn_if_router_not_scraped()
predictor_cls = LOAD_PREDICTORS[config.load_predictor]
self.num_req_predictor = predictor_cls(config)
self.isl_predictor = predictor_cls(config)
self.osl_predictor = predictor_cls(config)
# Optional warmup: preload predictors with historical observations from a
# mooncake-style JSONL trace (request_count/avg_isl/avg_osl per interval).
if config.load_predictor_warmup_trace is not None:
warmup_trace = config.load_predictor_warmup_trace
self.prometheus_port = config.metric_reporting_prometheus_port
self.prometheus_metrics = PlannerPrometheusMetrics()
if self.prometheus_port != 0:
try:
metrics = extract_metrics_from_mooncake(
warmup_trace, config.throughput_adjustment_interval
)
for m in metrics:
self.num_req_predictor.add_data_point(float(m["request_count"]))
self.isl_predictor.add_data_point(float(m["avg_isl"]))
self.osl_predictor.add_data_point(float(m["avg_osl"]))
start_http_server(self.prometheus_port)
logger.info(
f"Warmed load predictors with {len(metrics)} intervals from {warmup_trace}"
f"Started Prometheus metrics server on port {self.prometheus_port}"
)
except Exception as e:
logger.warning(
f"Failed to warm load predictors from {warmup_trace}: {e}"
)
finally:
# Even with warmup data, ignore the initial post-deploy idle
# period (leading zeros) when live metrics start coming in.
for p in (
self.num_req_predictor,
self.isl_predictor,
self.osl_predictor,
):
if hasattr(p, "reset_idle_skip"):
p.reset_idle_skip()
self.enable_load = config.enable_load_scaling
self.enable_throughput = config.enable_throughput_scaling
logger.error(f"Failed to start Prometheus metrics server: {e}")
# Worker info (resolved during _async_init)
self.prefill_worker_info = WorkerInfo()
self.decode_worker_info = WorkerInfo()
self.prefill_client = None
self.workers_client = None
# FPM subscribers (one per component type, populated during _async_init)
self._prefill_fpm_sub: Optional[FpmEventSubscriber] = None
self._decode_fpm_sub: Optional[FpmEventSubscriber] = None
self.prometheus_port = config.metric_reporting_prometheus_port
self.prometheus_metrics: PlannerPrometheusMetrics | None = None
# Runtime client caches
self._prefill_client = None
self._decode_client = None
if prometheus_metrics is None:
self.prometheus_metrics = PlannerPrometheusMetrics()
else:
self.prometheus_metrics = prometheus_metrics
# Shared metrics state
self._last_metrics = Metrics()
self._cumulative_gpu_hours: float = 0.0
if start_prometheus_server and self.prometheus_port != 0:
try:
start_http_server(self.prometheus_port)
logger.info(
f"Started Prometheus metrics server on port {self.prometheus_port}"
)
except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}")
# State machine (created after WorkerInfo is resolved)
self._state_machine: Optional[PlannerStateMachine] = None
self.fpm_subscriber: "Optional[FpmEventSubscriber]" = None
# ------------------------------------------------------------------
# State machine access
# ------------------------------------------------------------------
if self.component_type == SubComponentType.PREFILL:
self.ttft_regression = PrefillRegressionModel(
max_num_fpm_samples=self.config.max_num_fpm_samples,
min_observations=self.config.load_min_observations,
bucket_count=self.config.fpm_sample_bucket_size,
)
elif self.component_type == SubComponentType.DECODE:
self.itl_regression = DecodeRegressionModel(
max_num_fpm_samples=self.config.max_num_fpm_samples,
min_observations=self.config.load_min_observations,
bucket_count=self.config.fpm_sample_bucket_size,
def _ensure_state_machine(self) -> PlannerStateMachine:
if self._state_machine is None:
caps = build_worker_capabilities(
self.config,
self.prefill_worker_info,
self.decode_worker_info,
)
self._state_machine = PlannerStateMachine(self.config, caps)
self._warm_predictors()
return self._state_machine
@property
def last_metrics(self) -> Metrics:
return self.shared_state.last_metrics
def state_machine(self) -> PlannerStateMachine:
return self._ensure_state_machine()
@last_metrics.setter
def last_metrics(self, value: Metrics) -> None:
self.shared_state.last_metrics = value
def _warm_predictors(self) -> None:
if self.config.load_predictor_warmup_trace is None:
return
assert self._state_machine is not None
try:
metrics = extract_metrics_from_mooncake(
self.config.load_predictor_warmup_trace,
self.config.throughput_adjustment_interval,
)
self._state_machine.warm_load_predictors(
[
TrafficObservation(
duration_s=self.config.throughput_adjustment_interval,
num_req=float(m["request_count"]),
isl=float(m["avg_isl"]),
osl=float(m["avg_osl"]),
)
for m in metrics
]
)
except Exception as e:
logger.warning(f"Failed to warm load predictors: {e}")
async def _init_worker_info(
self, require_prefill: bool, require_decode: bool
) -> None:
"""Initialize WorkerInfo and model name in a single step."""
connector = getattr(self, "connector", None)
self.prefill_worker_info, self.decode_worker_info = resolve_worker_info(
backend=self.config.backend,
require_prefill=require_prefill,
require_decode=require_decode,
connector=connector,
config_model_name=getattr(self.config, "model_name", ""),
no_operation=self.config.no_operation,
)
# model_name is resolved and written into both WorkerInfo objects
self.model_name = (
self.decode_worker_info.model_name or self.prefill_worker_info.model_name
)
# ------------------------------------------------------------------
# Async init
# ------------------------------------------------------------------
async def _async_init(self):
"""Async initialization: connector init, deployment validation, WorkerInfo."""
async def _async_init(self) -> None:
if hasattr(self, "connector") and hasattr(self.connector, "_async_init"):
await self.connector._async_init()
require_prefill = self.component_type == SubComponentType.PREFILL
require_decode = self.component_type == SubComponentType.DECODE
if not self.config.no_operation:
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
logger.info("Validating deployment...")
await self.connector.validate_deployment(
prefill_component_name=(
defaults.prefill_worker_k8s_name
if require_prefill and defaults
if self.require_prefill and defaults
else None
),
decode_component_name=(
defaults.decode_worker_k8s_name
if require_decode and defaults
if self.require_decode and defaults
else None
),
require_prefill=require_prefill,
require_decode=require_decode,
require_prefill=self.require_prefill,
require_decode=self.require_decode,
)
logger.info("Successfully validated the deployment")
_initialize_gpu_counts(
self.config,
self.connector,
require_prefill=require_prefill,
require_decode=require_decode,
require_prefill=self.require_prefill,
require_decode=self.require_decode,
)
await self.connector.wait_for_deployment_ready(include_planner=False)
await self._init_worker_info(
require_prefill=require_prefill,
require_decode=require_decode,
)
await self._init_worker_info()
if self.runtime is not None:
await self._init_fpm_subscriber()
if self.require_prefill:
await self._init_fpm_subscriber("prefill")
if self.require_decode:
await self._init_fpm_subscriber("decode")
await self._bootstrap_regression()
async def _bootstrap_regression(self) -> None:
"""Fetch pre-deployment FPM data and bootstrap the regression model."""
worker_info = (
self.prefill_worker_info
if self.component_type == SubComponentType.PREFILL
else self.decode_worker_info
async def _init_worker_info(self) -> None:
connector = getattr(self, "connector", None)
self.prefill_worker_info, self.decode_worker_info = resolve_worker_info(
backend=self.config.backend,
require_prefill=self.require_prefill,
require_decode=self.require_decode,
connector=connector,
config_model_name=getattr(self.config, "model_name", ""),
no_operation=self.config.no_operation,
)
self.model_name = (
self.decode_worker_info.model_name or self.prefill_worker_info.model_name
)
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=self.component_type,
)
if self.component_type == SubComponentType.PREFILL:
self.ttft_regression.load_benchmark_fpms(fpms)
elif self.component_type == SubComponentType.DECODE:
self.itl_regression.load_benchmark_fpms(fpms)
logger.info(
f"Bootstrapped {self.component_type.value} regression with "
f"{len(fpms)} pre-deployment FPMs"
)
except Exception as e:
if self.enable_throughput:
raise
logger.warning(
f"No pre-deployment data for {self.component_type.value} regression: {e}. "
"Load-based scaling will learn from live FPM only."
)
async def _init_fpm_subscriber(self) -> None:
"""Create and start the FPM subscriber for load-based scaling."""
async def _init_fpm_subscriber(self, component: str) -> None:
from dynamo.llm import FpmEventSubscriber
worker_info = (
self.prefill_worker_info
if self.component_type == SubComponentType.PREFILL
if component == "prefill"
else self.decode_worker_info
)
if not worker_info.component_name or not worker_info.endpoint:
logger.warning(
"WorkerInfo missing component_name or endpoint, "
"cannot create FPM subscriber"
f"WorkerInfo missing for {component}, cannot create FPM subscriber"
)
return
......@@ -302,50 +292,49 @@ class BasePlanner:
endpoint = self.runtime.endpoint(
f"{self.namespace}.{worker_info.component_name}.{worker_info.endpoint}"
)
self.fpm_subscriber = FpmEventSubscriber(endpoint)
self.fpm_subscriber.start_tracking()
sub = FpmEventSubscriber(endpoint)
sub.start_tracking()
logger.info(
f"FPM tracker started for {worker_info.component_name}.{worker_info.endpoint}"
)
def _get_fpm_stats(self) -> "dict[tuple[str, int], ForwardPassMetrics]":
"""Get decoded FPM stats from the subscriber, keyed by (worker_id, dp_rank)."""
if component == "prefill":
self._prefill_fpm_sub = sub
else:
self._decode_fpm_sub = sub
async def _bootstrap_regression(self) -> None:
"""Override in subclasses to bootstrap regression models."""
pass
# ------------------------------------------------------------------
# Data collection (runtime I/O)
# ------------------------------------------------------------------
def _decode_fpm_bytes(
self, subscriber: Optional[FpmEventSubscriber]
) -> dict[tuple[str, int], ForwardPassMetrics]:
from dynamo.common.forward_pass_metrics import decode as decode_fpm
if self.fpm_subscriber is None:
if subscriber is None:
return {}
raw_stats = self.fpm_subscriber.get_recent_stats()
result = {}
for key, raw_bytes in raw_stats.items():
for key, raw_bytes in subscriber.get_recent_stats().items():
fpm = decode_fpm(raw_bytes)
if fpm is not None:
result[key] = fpm
return result
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert self.runtime is not None, "Runtime is not initialized"
assert self.runtime is not None
client = await self.runtime.endpoint(
f"{self.namespace}.{component_name}.{endpoint_name}"
).client()
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
return client
async def get_workers_info(
self, require_prefill: bool = True, require_decode: bool = True
) -> tuple[int, int, bool]:
"""
Get worker counts for prefill and decode components.
Returns:
tuple[int, int, bool]: (num_p_workers, num_d_workers, is_stable)
- is_stable: False if rollout in progress (scaling should be skipped)
"""
num_p_workers = 0
num_d_workers = 0
# For Kubernetes, use DGD status instead of runtime client
async def _get_worker_counts_raw(self) -> tuple[int, int, bool]:
"""Returns (num_prefill, num_decode, is_stable) from connector or runtime."""
if hasattr(self, "connector") and isinstance(
self.connector, KubernetesConnector
):
......@@ -355,515 +344,230 @@ class BasePlanner:
is_stable,
) = self.connector.get_actual_worker_counts(
prefill_component_name=(
self.prefill_worker_info.k8s_name if require_prefill else None
self.prefill_worker_info.k8s_name if self.require_prefill else None
),
decode_component_name=(
self.decode_worker_info.k8s_name if require_decode else None
self.decode_worker_info.k8s_name if self.require_decode else None
),
)
num_p_workers = prefill_count if require_prefill else 0
num_d_workers = decode_count if require_decode else 0
return num_p_workers, num_d_workers, is_stable
return (
prefill_count if self.require_prefill else 0,
decode_count if self.require_decode else 0,
is_stable,
)
# Fall back to runtime client for non-Kubernetes environments
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
if require_prefill:
num_p, num_d = 0, 0
if self.require_prefill:
try:
if self.prefill_client is None:
if self._prefill_client is None:
assert self.prefill_worker_info.component_name is not None
assert self.prefill_worker_info.endpoint is not None
self.prefill_client = await self._get_or_create_client(
self._prefill_client = await self._get_or_create_client(
self.prefill_worker_info.component_name,
self.prefill_worker_info.endpoint,
)
num_p_workers = len(self.prefill_client.instance_ids()) # type: ignore
num_p = len(self._prefill_client.instance_ids()) # type: ignore
except Exception:
num_p_workers = 0
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
)
logger.warning("No prefill workers found")
if require_decode:
if self.require_decode:
try:
if self.workers_client is None:
if self._decode_client is None:
assert self.decode_worker_info.component_name is not None
assert self.decode_worker_info.endpoint is not None
self.workers_client = await self._get_or_create_client(
self._decode_client = await self._get_or_create_client(
self.decode_worker_info.component_name,
self.decode_worker_info.endpoint,
)
num_d_workers = len(self.workers_client.instance_ids()) # type: ignore
num_d = len(self._decode_client.instance_ids()) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return num_p_workers, num_d_workers, True # Always stable for non-K8s
async def observe_traffic_stats(
self, require_prefill: bool = True, require_decode: bool = True
) -> None:
"""
Observe metrics from Prometheus and update shared state.
"""
num_p_workers, num_d_workers, _ = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.num_p_workers = num_p_workers
self.shared_state.num_d_workers = num_d_workers
logger.debug(
f"Number of prefill workers: {num_p_workers}, number of decode workers: {num_d_workers}"
)
return num_p, num_d, True
# Update Prometheus metrics if server is running
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.num_p_workers.set(num_p_workers)
self.prometheus_metrics.num_d_workers.set(num_d_workers)
async def _collect_traffic(self) -> Optional[TrafficObservation]:
"""Pull traffic metrics from Prometheus."""
num_p, num_d, _ = await self._get_worker_counts_raw()
# Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours = (
if self.prometheus_port != 0:
self.prometheus_metrics.num_p_workers.set(num_p)
self.prometheus_metrics.num_d_workers.set(num_d)
gpu_hours = (
(
num_p_workers * (self.config.prefill_engine_num_gpu or 0)
+ num_d_workers * (self.config.decode_engine_num_gpu or 0)
num_p * (self.config.prefill_engine_num_gpu or 0)
+ num_d * (self.config.decode_engine_num_gpu or 0)
)
* self.config.throughput_adjustment_interval
/ 3600
)
self.shared_state.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(
self.shared_state.cumulative_gpu_hours
)
# Prometheus returns seconds, convert to milliseconds
assert (
self.model_name is not None
), "model_name must be set before observing traffic stats"
self._cumulative_gpu_hours += gpu_hours
self.prometheus_metrics.gpu_hours.set(self._cumulative_gpu_hours)
assert self.model_name is not None
interval_str = f"{self.config.throughput_adjustment_interval}s"
self.last_metrics.ttft = (
m = self._last_metrics
m.ttft = (
self.prometheus_traffic_client.get_avg_time_to_first_token(
interval_str,
self.model_name,
interval_str, self.model_name
)
* 1000
)
self.last_metrics.itl = (
m.itl = (
self.prometheus_traffic_client.get_avg_inter_token_latency(
interval_str,
self.model_name,
interval_str, self.model_name
)
* 1000
)
self.last_metrics.num_req = (
self.prometheus_traffic_client.get_avg_request_count(
interval_str,
self.model_name,
)
m.num_req = self.prometheus_traffic_client.get_avg_request_count(
interval_str, self.model_name
)
self.last_metrics.request_duration = (
self.prometheus_traffic_client.get_avg_request_duration(
interval_str,
self.model_name,
)
m.request_duration = self.prometheus_traffic_client.get_avg_request_duration(
interval_str, self.model_name
)
self.last_metrics.isl = (
self.prometheus_traffic_client.get_avg_input_sequence_tokens(
interval_str,
self.model_name,
)
m.isl = self.prometheus_traffic_client.get_avg_input_sequence_tokens(
interval_str, self.model_name
)
self.last_metrics.osl = (
self.prometheus_traffic_client.get_avg_output_sequence_tokens(
interval_str,
self.model_name,
)
m.osl = self.prometheus_traffic_client.get_avg_output_sequence_tokens(
interval_str, self.model_name
)
logger.info(
f"Observed num_req: {self.last_metrics.num_req:.2f} isl: {self.last_metrics.isl:.2f} osl: {self.last_metrics.osl:.2f}"
)
logger.info(
f"Observed ttft: {self.last_metrics.ttft:.2f}ms itl: {self.last_metrics.itl:.2f}ms"
f"Observed num_req: {m.num_req:.2f} isl: {m.isl:.2f} osl: {m.osl:.2f}"
)
# Update observed metrics in Prometheus
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft)
self.prometheus_metrics.observed_itl.set(self.last_metrics.itl)
if self.prometheus_port != 0:
self.prometheus_metrics.observed_ttft.set(m.ttft)
self.prometheus_metrics.observed_itl.set(m.itl)
self.prometheus_metrics.observed_request_rate.set(
self.last_metrics.num_req / self.config.throughput_adjustment_interval
)
self.prometheus_metrics.observed_request_duration.set(
self.last_metrics.request_duration
m.num_req / self.config.throughput_adjustment_interval
)
self.prometheus_metrics.observed_isl.set(self.last_metrics.isl)
self.prometheus_metrics.observed_osl.set(self.last_metrics.osl)
self.prometheus_metrics.observed_request_duration.set(m.request_duration)
self.prometheus_metrics.observed_isl.set(m.isl)
self.prometheus_metrics.observed_osl.set(m.osl)
self.update_predictors_from_metrics(self.last_metrics)
def update_predictors_from_metrics(self, metrics: Metrics) -> None:
if metrics.num_req is not None:
self.num_req_predictor.add_data_point(metrics.num_req)
if metrics.isl is not None:
self.isl_predictor.add_data_point(metrics.isl)
if metrics.osl is not None:
self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try:
next_num_req = self.num_req_predictor.predict_next()
next_isl = self.isl_predictor.predict_next()
next_osl = self.osl_predictor.predict_next()
logger.info(
f"Predicted load: num_req={next_num_req:.2f}, isl={next_isl:.2f}, osl={next_osl:.2f}"
)
return next_num_req, next_isl, next_osl
except Exception as e:
logger.error(f"Failed to predict load: {e}")
return None, None, None
def plan_adjustment(self) -> Optional[int]:
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return None
next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is None or next_isl is None or next_osl is None:
if not m.is_valid():
logger.info("Metrics contain None or NaN values, skipping")
return None
# Update predicted load metrics in Prometheus
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_request_rate.set(
next_num_req / self.config.throughput_adjustment_interval
)
self.prometheus_metrics.predicted_isl.set(next_isl)
self.prometheus_metrics.predicted_osl.set(next_osl)
try:
return self._compute_replica_requirements(next_num_req, next_isl, next_osl)
except Exception as e:
logger.error(f"Failed to compute number of replicas: {e}")
return None
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
raise NotImplementedError
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> Optional[int]:
raise NotImplementedError
def _component_name(self) -> str:
if self.component_type == SubComponentType.PREFILL:
assert self.prefill_worker_info.k8s_name is not None
return self.prefill_worker_info.k8s_name
assert self.decode_worker_info.k8s_name is not None
return self.decode_worker_info.k8s_name
def _engine_num_gpu(self) -> int:
if self.component_type == SubComponentType.PREFILL:
assert self.config.prefill_engine_num_gpu is not None
return self.config.prefill_engine_num_gpu
assert self.config.decode_engine_num_gpu is not None
return self.config.decode_engine_num_gpu
def apply_component_budget(self, desired_replicas: int) -> int:
return _apply_component_gpu_budget(
max(desired_replicas, self.config.min_endpoint),
self._engine_num_gpu(),
self.config,
)
async def _apply_scaling(self, desired_replicas: int) -> None:
if self.config.no_operation:
return
target_replicas = [
TargetReplica(
sub_component_type=self.component_type,
component_name=self._component_name(),
desired_replicas=desired_replicas,
)
]
await self.connector.set_component_replicas(target_replicas, blocking=False)
_apply_scaling_blocking = _apply_scaling
@staticmethod
def _reconcile_fpm_worker_count(
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
dgd_count: int,
label: str,
) -> bool:
"""Validate that FPM coverage matches DGD worker count, accounting for DP.
With attention DP, each worker emits FPM per dp_rank. We check that
the number of unique worker IDs matches DGD, and that all workers
have the same number of dp_ranks (complete coverage).
Returns True if counts match, False otherwise.
"""
workers_to_dp: dict[str, set[int]] = {}
for wid, dp in fpm_stats:
workers_to_dp.setdefault(wid, set()).add(dp)
fpm_worker_count = len(workers_to_dp)
if fpm_worker_count != dgd_count:
logger.warning(
f"Worker count mismatch: DGD reports {dgd_count}, "
f"FPM reports {fpm_worker_count} workers for {label}. "
"Skipping scaling."
)
return False
dp_sizes = {len(dps) for dps in workers_to_dp.values()}
if len(dp_sizes) > 1:
logger.warning(
f"Inconsistent DP ranks across workers for {label}: "
f"{dict(workers_to_dp)}. Skipping scaling."
)
return False
dp_size = dp_sizes.pop() if dp_sizes else 1
expected_total = dgd_count * dp_size
actual_total = len(fpm_stats)
if actual_total != expected_total:
logger.warning(
f"Incomplete FPM coverage for {label}: expected "
f"{dgd_count} workers × {dp_size} dp_ranks = {expected_total}, "
f"got {actual_total}. Skipping scaling."
)
return False
if dp_size > 1:
logger.info(
f"FPM {label}: {fpm_worker_count} workers × {dp_size} dp_ranks "
f"= {actual_total} engines"
)
return True
@staticmethod
def _log_fpm(wid: str, dp: int, fpm: "ForwardPassMetrics", label: str) -> None:
sched = fpm.scheduled_requests
queued = fpm.queued_requests
logger.info(
f"FPM {label} engine {wid}:dp{dp}: "
f"wall_time={fpm.wall_time:.4f}s, "
f"sched(prefill_tok={sched.sum_prefill_tokens}, "
f"prefill_req={sched.num_prefill_requests}, "
f"decode_kv={sched.sum_decode_kv_tokens}, "
f"decode_req={sched.num_decode_requests}), "
f"queued(prefill_tok={queued.sum_prefill_tokens}, "
f"decode_kv={queued.sum_decode_kv_tokens})"
return TrafficObservation(
duration_s=self.config.throughput_adjustment_interval,
num_req=m.num_req,
isl=m.isl,
osl=m.osl,
)
def observe_fpm_load_stats(
self,
) -> "dict[tuple[str, int], ForwardPassMetrics]":
"""Get latest FPM stats and feed observations into the regression model.
Returns:
The decoded FPM stats dict for use by load_plan_adjustment().
"""
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
logger.warning(
f"No FPM data available for {self.component_type.value} (tracker empty)"
)
return {}
for (wid, dp), fpm in fpm_stats.items():
self._log_fpm(wid, dp, fpm, self.component_type.value)
if self.component_type == SubComponentType.PREFILL:
self.ttft_regression.add_observation(fpm)
elif self.component_type == SubComponentType.DECODE:
self.itl_regression.add_observation(fpm)
logger.info(
f"FPM load stats: {len(fpm_stats)} engines observed for "
f"{self.component_type.value}"
def _collect_fpm(self) -> FpmObservations:
"""Collect FPM from active subscribers."""
prefill_stats = None
decode_stats = None
if self._prefill_fpm_sub is not None:
stats = self._decode_fpm_bytes(self._prefill_fpm_sub)
if stats:
for (wid, dp), fpm in stats.items():
_log_fpm(wid, dp, fpm, "prefill")
prefill_stats = stats
if self._decode_fpm_sub is not None:
stats = self._decode_fpm_bytes(self._decode_fpm_sub)
if stats:
for (wid, dp), fpm in stats.items():
_log_fpm(wid, dp, fpm, "decode")
decode_stats = stats
return FpmObservations(prefill=prefill_stats, decode=decode_stats)
async def _collect_worker_counts(self) -> WorkerCounts:
num_p, num_d, is_stable = await self._get_worker_counts_raw()
return WorkerCounts(
ready_num_prefill=num_p if self.require_prefill else None,
ready_num_decode=num_d if self.require_decode else None,
expected_num_prefill=(num_p if is_stable else None)
if self.require_prefill
else None,
expected_num_decode=(num_d if is_stable else None)
if self.require_decode
else None,
)
return fpm_stats
def _load_based_scaling_decision_from_estimates(
self,
estimates: list[float],
sla: float,
num_workers: int,
label: str,
) -> Optional[int]:
"""Shared scale-up/down logic from per-engine latency estimates (ms).
Args:
estimates: per-engine estimated latencies in ms.
sla: target SLA in ms (e.g. config.ttft or config.itl).
num_workers: current worker count for this component.
label: human-readable label for log messages (e.g. "prefill TTFT").
Returns:
Desired replica count, or None if no scaling action needed.
"""
if not estimates:
return None
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
logger.info(
f"Load-based {label}: workers={num_workers}, sla={sla:.1f}ms, "
f"estimates={[f'{t:.1f}' for t in estimates]}"
# ------------------------------------------------------------------
# Gather tick input
# ------------------------------------------------------------------
async def _gather_tick_input(self, tick: ScheduledTick) -> TickInput:
now = time.time()
traffic = None
worker_counts = None
fpm_obs = None
if tick.need_traffic_metrics:
traffic = await self._collect_traffic()
if tick.need_worker_states:
worker_counts = await self._collect_worker_counts()
if tick.need_worker_fpm:
fpm_obs = self._collect_fpm()
return TickInput(
now_s=now,
traffic=traffic,
worker_counts=worker_counts,
fpm_observations=fpm_obs,
)
if all(t > sla for t in estimates):
logger.info(
f"Load-based {label}: ALL engines above SLA ({sla:.1f}ms), "
f"scaling up to {num_workers + 1}"
)
return num_workers + 1
if num_workers > 1:
threshold = sla * sensitivity
if all(t < threshold for t in estimates):
desired = max(num_workers - 1, self.config.min_endpoint)
if desired == num_workers:
logger.info(
f"Load-based {label}: ALL engines below threshold "
f"({threshold:.1f}ms), but at min_endpoint ({self.config.min_endpoint})"
)
else:
logger.info(
f"Load-based {label}: ALL engines below threshold "
f"({threshold:.1f}ms), scaling down to {desired}"
)
return desired
return None
# ------------------------------------------------------------------
# Apply effects (override in subclasses for mode-specific metrics)
# ------------------------------------------------------------------
def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision. Override in subclasses."""
raise NotImplementedError
async def _apply_effects(self, effects: PlannerEffects) -> None:
"""Override in subclasses to report metrics and apply scaling."""
pass
async def _throughput_loop(
self, require_prefill: bool, require_decode: bool
async def _apply_scaling_targets(
self, targets: list[TargetReplica], blocking: bool = False
) -> None:
"""Throughput-based scaling loop (existing behavior, extracted from run())."""
while True:
current_time = time.time()
if (
current_time - self.shared_state.last_adjustment_time
>= self.config.throughput_adjustment_interval
):
self.shared_state.last_adjustment_time = time.time()
logger.info("New throughput adjustment interval started!")
await self.observe_traffic_stats(
require_prefill=require_prefill, require_decode=require_decode
)
desired_replicas = self.plan_adjustment()
if desired_replicas is not None:
if self.enable_load:
# When load-based is also enabled: just set lower bound
if self.component_type == SubComponentType.PREFILL:
self.shared_state.throughput_lower_bound_p = (
desired_replicas
)
else:
self.shared_state.throughput_lower_bound_d = (
desired_replicas
)
logger.info(
f"Throughput lower bound set to {desired_replicas} for {self.component_type.value}"
)
else:
# Throughput-only: apply scaling directly
desired_replicas = self.apply_component_budget(desired_replicas)
self.update_predicted_replicas_metric(desired_replicas)
# Throughput planner does not needs blocking scaling because it monitors
# and predicts the load, not relying on the current status of the engine.
await self._apply_scaling(desired_replicas)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_and_fpm_update_loop(
self, require_prefill: bool, require_decode: bool
) -> None:
"""FPM observation and (optionally) load-based scaling loop.
Runs every load_adjustment_interval. Always updates the FPM
regression model with live observations. When load-based scaling
is enabled, also makes scaling decisions immediately after the
FPM update.
"""
pending_desired: Optional[int] = None
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load/FPM update interval started!")
num_p, num_d, is_stable = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d
fpm_stats = self.observe_fpm_load_stats()
if not fpm_stats:
continue
"""Shared helper: send scaling targets to connector."""
if self.config.no_operation or not targets:
return
await self.connector.set_component_replicas(targets, blocking=blocking)
if not self.enable_load:
continue
# ------------------------------------------------------------------
# Main loop
# ------------------------------------------------------------------
if pending_desired is not None:
dgd_count = (
num_p if self.component_type == SubComponentType.PREFILL else num_d
)
if dgd_count == pending_desired:
logger.info(
f"Scaling to {pending_desired} complete, resuming decisions"
)
pending_desired = None
else:
logger.info(
f"Scaling in progress ({dgd_count} -> {pending_desired}), "
"observing only"
)
continue
async def run(self) -> None:
next_tick = self.state_machine.initial_tick(time.time())
poll_interval = self.config.load_adjustment_interval / 10
dgd_count = (
num_p if self.component_type == SubComponentType.PREFILL else num_d
)
if not self._reconcile_fpm_worker_count(
fpm_stats, dgd_count, self.component_type.value
):
while True:
now = time.time()
if now < next_tick.at_s:
await asyncio.sleep(min(next_tick.at_s - now, poll_interval))
continue
desired_replicas = self.load_plan_adjustment()
if desired_replicas is not None:
if self.enable_throughput:
if self.component_type == SubComponentType.PREFILL:
lower_bound = self.shared_state.throughput_lower_bound_p
else:
lower_bound = self.shared_state.throughput_lower_bound_d
desired_replicas = max(desired_replicas, lower_bound)
desired_replicas = self.apply_component_budget(desired_replicas)
self.update_predicted_replicas_metric(desired_replicas)
pending_desired = desired_replicas
await self._apply_scaling_blocking(desired_replicas)
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
require_prefill = self.component_type == SubComponentType.PREFILL
require_decode = self.component_type == SubComponentType.DECODE
self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time()
loops = []
if self.enable_throughput:
loops.append(self._throughput_loop(require_prefill, require_decode))
loops.append(self._load_and_fpm_update_loop(require_prefill, require_decode))
await asyncio.gather(*loops)
tick_input = await self._gather_tick_input(next_tick)
effects = self.state_machine.on_tick(next_tick, tick_input)
await self._apply_effects(effects)
assert effects.next_tick is not None
next_tick = effects.next_tick
# ------------------------------------------------------------------
# Shared utility
# ------------------------------------------------------------------
def _log_fpm(wid: str, dp: int, fpm: ForwardPassMetrics, label: str) -> None:
sched = fpm.scheduled_requests
queued = fpm.queued_requests
logger.info(
f"FPM {label} engine {wid}:dp{dp}: "
f"wall_time={fpm.wall_time:.4f}s, "
f"sched(prefill_tok={sched.sum_prefill_tokens}, "
f"prefill_req={sched.num_prefill_requests}, "
f"decode_kv={sched.sum_decode_kv_tokens}, "
f"decode_req={sched.num_decode_requests}), "
f"queued(prefill_tok={queued.sum_prefill_tokens}, "
f"decode_kv={queued.sum_decode_kv_tokens})"
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Optional
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.core.base import BasePlanner
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE
def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for decode using FPM data.
For each engine, estimates next decode ITL:
- Uses scheduled + queued decode KV tokens + avg decode length
- Predicts wall time via regression
Scale up if ALL engines' estimated ITL > SLA.
Scale down if ALL engines' estimated ITL < SLA * sensitivity.
"""
if not self.itl_regression.has_sufficient_data():
logger.info(
f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
f"/{self.itl_regression.min_observations}), skipping load-based scaling"
)
return None
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
return None
num_workers = self.shared_state.num_d_workers
if num_workers == 0:
return None
estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
scheduled_kv = fpm.scheduled_requests.sum_decode_kv_tokens
queued_kv = fpm.queued_requests.sum_decode_kv_tokens
est = self.itl_regression.estimate_next_itl(
scheduled_decode_kv=scheduled_kv,
queued_decode_kv=queued_kv,
)
if est is None:
continue
est_ms = est * 1000
estimated_itls.append(est_ms)
logger.info(
f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
f"(sched_kv={scheduled_kv}, queued_kv={queued_kv}, "
f"avg_decode_len={self.itl_regression.avg_decode_length:.1f})"
)
return self._load_based_scaling_decision_from_estimates(
estimates=estimated_itls,
sla=self.config.itl,
num_workers=num_workers,
label="decode ITL",
)
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> Optional[int]:
demand_rps = next_num_req / self.config.throughput_adjustment_interval
engine_rps, actual_itl_ms = self.itl_regression.find_best_engine_decode_rps(
itl=self.config.itl,
context_length=next_isl + next_osl / 2,
osl=next_osl,
)
if engine_rps <= 0:
logger.warning("Decode perf model not ready, skipping throughput scaling")
return None
if actual_itl_ms > self.config.itl:
logger.warning(
f"Decode ITL SLA not met: {actual_itl_ms:.1f}ms > "
f"{self.config.itl:.1f}ms, scaling with best achievable rate"
)
next_num_d = math.ceil(demand_rps / engine_rps)
next_num_d = max(next_num_d, self.config.min_endpoint)
logger.info(
f"Decode: {demand_rps:.2f}(demand rps) / "
f"{engine_rps:.2f}(engine rps) = {next_num_d}(num_d), "
f"est_itl={actual_itl_ms:.1f}ms"
)
return next_num_d
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_num_d.set(desired_replicas)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import time
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.base import BasePlanner
from dynamo.planner.core.budget import _apply_global_gpu_budget, _initialize_gpu_counts
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class DisaggPlanner:
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config
self.shared_state = PlannerSharedState()
prometheus_metrics = PlannerPrometheusMetrics()
self.enable_throughput = config.enable_throughput_scaling
self.enable_load = config.enable_load_scaling
self.prefill_planner = PrefillPlanner(
runtime,
config,
shared_state=self.shared_state,
prometheus_metrics=prometheus_metrics,
start_prometheus_server=True,
)
self.decode_planner = DecodePlanner(
runtime,
config,
shared_state=self.shared_state,
prometheus_metrics=prometheus_metrics,
prometheus_traffic_client=getattr(
self.prefill_planner, "prometheus_traffic_client", None
),
connector=getattr(self.prefill_planner, "connector", None),
start_prometheus_server=False,
)
async def _async_init(self):
# DisaggPlanner overrides _async_init to handle both prefill+decode
# and share WorkerInfo between the two sub-planners.
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
if not self.config.no_operation:
# Connector init (prefill/decode share the same connector)
connector = getattr(self.prefill_planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...")
await self.prefill_planner.connector.validate_deployment(
prefill_component_name=(
defaults.prefill_worker_k8s_name if defaults else None
),
decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=True,
require_decode=True,
)
logger.info("Successfully validated the deployment")
_initialize_gpu_counts(
self.config,
self.prefill_planner.connector,
require_prefill=True,
require_decode=True,
)
await self.prefill_planner.connector.wait_for_deployment_ready(
include_planner=False
)
await self.prefill_planner._init_worker_info(
require_prefill=True, require_decode=True
)
# Share WorkerInfo and model name with decode planner
self.decode_planner.prefill_worker_info = (
self.prefill_planner.prefill_worker_info
)
self.decode_planner.decode_worker_info = self.prefill_planner.decode_worker_info
self.decode_planner.model_name = self.prefill_planner.model_name
if self.prefill_planner.runtime is not None:
await self.prefill_planner._init_fpm_subscriber()
if self.decode_planner.runtime is not None:
await self.decode_planner._init_fpm_subscriber()
await self.prefill_planner._bootstrap_regression()
await self.decode_planner._bootstrap_regression()
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time()
loops = []
if self.enable_throughput:
loops.append(self._throughput_loop())
loops.append(self._load_and_fpm_update_loop())
await asyncio.gather(*loops)
async def _throughput_loop(self) -> None:
"""Throughput-based scaling loop for disagg mode."""
while True:
current_time = time.time()
if (
current_time - self.shared_state.last_adjustment_time
>= self.config.throughput_adjustment_interval
):
self.shared_state.last_adjustment_time = time.time()
logger.info("New throughput adjustment interval started!")
await self.prefill_planner.observe_traffic_stats(
require_prefill=True, require_decode=True
)
self.decode_planner.update_predictors_from_metrics(
self.shared_state.last_metrics
)
next_num_p = self.prefill_planner.plan_adjustment()
next_num_d = self.decode_planner.plan_adjustment()
if next_num_p is None or next_num_d is None:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
if self.enable_load:
# When load-based is also enabled: just set lower bounds
self.shared_state.throughput_lower_bound_p = next_num_p
self.shared_state.throughput_lower_bound_d = next_num_d
logger.info(
f"Throughput lower bounds set: prefill={next_num_p}, decode={next_num_d}"
)
else:
# Throughput-only: apply scaling directly
next_num_p, next_num_d = _apply_global_gpu_budget(
next_num_p, next_num_d, self.config
)
self.prefill_planner.update_predicted_replicas_metric(next_num_p)
self.decode_planner.update_predicted_replicas_metric(next_num_d)
if not self.config.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=next_num_d,
),
]
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=False
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_and_fpm_update_loop(self) -> None:
"""FPM observation and (optionally) load-based scaling for disagg mode.
Always updates regression models with live FPM. When load-based
scaling is enabled, makes scaling decisions immediately after.
"""
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load/FPM update interval started!")
num_p, num_d, _ = await self.prefill_planner.get_workers_info(
require_prefill=True, require_decode=True
)
self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d
p_stats = self.prefill_planner.observe_fpm_load_stats()
d_stats = self.decode_planner.observe_fpm_load_stats()
if not self.enable_load:
continue
if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping")
continue
if p_stats and not BasePlanner._reconcile_fpm_worker_count(
p_stats, num_p, "prefill"
):
continue
if d_stats and not BasePlanner._reconcile_fpm_worker_count(
d_stats, num_d, "decode"
):
continue
p_desired = self.prefill_planner.load_plan_adjustment()
d_desired = self.decode_planner.load_plan_adjustment()
final_p = (
p_desired if p_desired is not None else self.shared_state.num_p_workers
)
final_d = (
d_desired if d_desired is not None else self.shared_state.num_d_workers
)
if (
final_p == self.shared_state.num_p_workers
and final_d == self.shared_state.num_d_workers
):
logger.info("Load-based scaling: no scaling needed")
continue
if self.enable_throughput:
final_p = max(final_p, self.shared_state.throughput_lower_bound_p)
final_d = max(final_d, self.shared_state.throughput_lower_bound_d)
final_p = max(final_p, self.config.min_endpoint)
final_d = max(final_d, self.config.min_endpoint)
final_p, final_d = _apply_global_gpu_budget(final_p, final_d, self.config)
logger.info(
f"Load-based disagg scaling: prefill {self.shared_state.num_p_workers}->{final_p}, "
f"decode {self.shared_state.num_d_workers}->{final_d}"
)
self.prefill_planner.update_predicted_replicas_metric(final_p)
self.decode_planner.update_predicted_replicas_metric(final_d)
if not self.config.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=final_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=final_d,
),
]
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=True
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Load-based scaling logic (FPM-driven, reactive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from dynamo.planner.core.types import FpmObservations, ScalingDecision
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class LoadScalingMixin:
"""FPM-driven load-based scaling decisions."""
def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]:
if not self._config.enable_load_scaling:
return None
mode = self._config.mode
if mode == "agg":
return self._advance_load_agg(obs)
if mode == "disagg":
return self._advance_load_disagg(obs)
return self._advance_load_single(obs, mode)
def _advance_load_single(
self, obs: FpmObservations, component: str
) -> Optional[ScalingDecision]:
if self._scaling_in_progress(component):
logger.info(f"Scaling in progress for {component}, observing only")
return None
fpm_stats = obs.prefill if component == "prefill" else obs.decode
num_workers = (
self._num_p_workers if component == "prefill" else self._num_d_workers
)
if not fpm_stats:
return None
if not self._reconcile_fpm_worker_count(fpm_stats, num_workers, component):
return None
desired = (
self._prefill_load_decision(fpm_stats, num_workers)
if component == "prefill"
else self._decode_load_decision(fpm_stats, num_workers)
)
if desired is None:
return None
if self._config.enable_throughput_scaling:
bound = (
self._throughput_lower_bound_p
if component == "prefill"
else self._throughput_lower_bound_d
)
desired = max(desired, bound)
desired = self._apply_single_budget(desired, component)
return (
ScalingDecision(num_prefill=desired)
if component == "prefill"
else ScalingDecision(num_decode=desired)
)
def _advance_load_disagg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
p_stats, d_stats = obs.prefill, obs.decode
if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping")
return None
if p_stats and not self._reconcile_fpm_worker_count(
p_stats, self._num_p_workers, "prefill"
):
return None
if d_stats and not self._reconcile_fpm_worker_count(
d_stats, self._num_d_workers, "decode"
):
return None
p_desired = (
self._prefill_load_decision(p_stats, self._num_p_workers)
if p_stats
else None
)
d_desired = (
self._decode_load_decision(d_stats, self._num_d_workers)
if d_stats
else None
)
final_p = p_desired if p_desired is not None else self._num_p_workers
final_d = d_desired if d_desired is not None else self._num_d_workers
if final_p == self._num_p_workers and final_d == self._num_d_workers:
logger.info("Load-based scaling: no scaling needed")
return None
if self._config.enable_throughput_scaling:
final_p = max(final_p, self._throughput_lower_bound_p)
final_d = max(final_d, self._throughput_lower_bound_d)
final_p = max(final_p, self._config.min_endpoint)
final_d = max(final_d, self._config.min_endpoint)
final_p, final_d = self._apply_global_budget(final_p, final_d)
logger.info(
f"Load-based disagg scaling: prefill {self._num_p_workers}->{final_p}, "
f"decode {self._num_d_workers}->{final_d}"
)
return ScalingDecision(num_prefill=final_p, num_decode=final_d)
def _advance_load_agg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
fpm_stats = obs.decode
if not fpm_stats:
return None
num_workers = self._num_d_workers
if self._scaling_in_progress("decode"):
logger.info(
f"Scaling in progress ({num_workers} -> {self._expected_num_d}), observing only"
)
return None
if not self._reconcile_fpm_worker_count(fpm_stats, num_workers, "agg"):
return None
if not self._agg_regression.has_sufficient_data():
logger.info(
f"Agg regression: insufficient data "
f"({self._agg_regression.num_observations}/{self._agg_regression.min_observations})"
)
return None
d_caps = self._capabilities.decode
max_tokens = d_caps.max_num_batched_tokens if d_caps else None
if not max_tokens or max_tokens <= 0:
logger.warning("max_num_batched_tokens not available, skipping agg scaling")
return None
p_desired = self._agg_prefill_scaling(fpm_stats, num_workers, max_tokens)
d_desired = self._agg_decode_scaling(fpm_stats, num_workers)
logger.info(
f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} (current={num_workers})"
)
if p_desired is not None and p_desired > num_workers:
desired = p_desired
elif d_desired is not None and d_desired > num_workers:
desired = d_desired
elif (
p_desired is not None
and p_desired < num_workers
and d_desired is not None
and d_desired < num_workers
):
desired = max(p_desired, d_desired)
else:
logger.info("Agg scaling: no scaling needed")
return None
desired = max(desired, self._config.min_endpoint)
if self._config.enable_throughput_scaling:
desired = max(desired, self._throughput_lower_bound_d)
desired = self._apply_single_budget(desired, "decode")
logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
return ScalingDecision(num_decode=desired)
# ------------------------------------------------------------------
# Per-engine latency estimation
# ------------------------------------------------------------------
def _prefill_load_decision(
self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
) -> Optional[int]:
if not self._prefill_regression.has_sufficient_data():
logger.info(
f"TTFT regression: insufficient data "
f"({self._prefill_regression.num_observations}/{self._prefill_regression.min_observations})"
)
return None
if num_workers == 0:
return None
p_caps = self._capabilities.prefill
max_tokens = p_caps.max_num_batched_tokens if p_caps else None
if not max_tokens or max_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available, skipping prefill load scaling"
)
return None
estimates: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self._prefill_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens,
)
if est is not None:
est_ms = est * 1000
estimates.append(est_ms)
logger.info(
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued={fpm.queued_requests.sum_prefill_tokens}, "
f"avg_isl={self._prefill_regression.avg_isl:.1f})"
)
return self._scale_decision(
estimates, self._config.ttft, num_workers, "prefill TTFT"
)
def _decode_load_decision(
self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
) -> Optional[int]:
if not self._decode_regression.has_sufficient_data():
logger.info(
f"ITL regression: insufficient data "
f"({self._decode_regression.num_observations}/{self._decode_regression.min_observations})"
)
return None
if num_workers == 0:
return None
estimates: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self._decode_regression.estimate_next_itl(
scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
)
if est is not None:
est_ms = est * 1000
estimates.append(est_ms)
logger.info(
f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
f"(sched_kv={fpm.scheduled_requests.sum_decode_kv_tokens}, "
f"queued_kv={fpm.queued_requests.sum_decode_kv_tokens})"
)
return self._scale_decision(
estimates, self._config.itl, num_workers, "decode ITL"
)
def _agg_prefill_scaling(
self,
fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
num_workers: int,
max_tokens: int,
) -> Optional[int]:
estimates: list[float] = []
for fpm in fpm_stats.values():
est = self._agg_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
)
if est is not None:
estimates.append(est * 1000)
return self._scale_decision(
estimates, self._config.ttft, num_workers, "agg TTFT"
)
def _agg_decode_scaling(
self,
fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
num_workers: int,
) -> Optional[int]:
estimates: list[float] = []
for fpm in fpm_stats.values():
est = self._agg_regression.estimate_next_itl(
scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
)
if est is not None:
estimates.append(est * 1000)
return self._scale_decision(estimates, self._config.itl, num_workers, "agg ITL")
def _scale_decision(
self, estimates: list[float], sla: float, num_workers: int, label: str
) -> Optional[int]:
if not estimates:
return None
sensitivity = self._config.load_scaling_down_sensitivity / 100.0
logger.info(
f"Load-based {label}: workers={num_workers}, sla={sla:.1f}ms, "
f"estimates={[f'{t:.1f}' for t in estimates]}"
)
if all(t > sla for t in estimates):
logger.info(
f"Load-based {label}: ALL above SLA, scaling up to {num_workers + 1}"
)
return num_workers + 1
if num_workers > 1:
threshold = sla * sensitivity
if all(t < threshold for t in estimates):
desired = max(num_workers - 1, self._config.min_endpoint)
logger.info(
f"Load-based {label}: ALL below threshold ({threshold:.1f}ms), -> {desired}"
)
return desired
return None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Optional
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.core.base import BasePlanner
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL
def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for prefill using FPM data.
For each engine, simulates prefill scheduling to estimate next TTFT:
- Uses queued prefill tokens + avg ISL as total tokens to process
- Chunks into max_num_batched_tokens-sized iterations
- Sums regression-predicted wall time per chunk
Scale up if ALL engines' estimated TTFT > SLA.
Scale down if ALL engines' estimated TTFT < SLA * sensitivity.
"""
if not self.ttft_regression.has_sufficient_data():
logger.info(
f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}"
f"/{self.ttft_regression.min_observations}), skipping load-based scaling"
)
return None
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
return None
num_workers = self.shared_state.num_p_workers
if num_workers == 0:
return None
max_num_batched_tokens = getattr(
self.prefill_worker_info, "max_num_batched_tokens", None
)
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available from WorkerInfo, "
"skipping prefill load-based scaling"
)
return None
estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
queued_prefill = fpm.queued_requests.sum_prefill_tokens
est = self.ttft_regression.estimate_next_ttft(
queued_prefill_tokens=queued_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
if est is None:
continue
est_ms = est * 1000
estimated_ttfts.append(est_ms)
logger.info(
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued_prefill={queued_prefill}, avg_isl={self.ttft_regression.avg_isl:.1f})"
)
return self._load_based_scaling_decision_from_estimates(
estimates=estimated_ttfts,
sla=self.config.ttft,
num_workers=num_workers,
label="prefill TTFT",
)
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> Optional[int]:
demand_rps = next_num_req / self.config.throughput_adjustment_interval
engine_rps, actual_ttft_ms = self.ttft_regression.find_best_engine_prefill_rps(
ttft_sla=self.config.ttft, isl=next_isl
)
if engine_rps <= 0:
logger.warning("Prefill perf model not ready, skipping throughput scaling")
return None
if actual_ttft_ms > self.config.ttft:
logger.warning(
f"Prefill TTFT SLA not met: {actual_ttft_ms:.1f}ms > "
f"{self.config.ttft:.1f}ms, scaling with best achievable rate"
)
next_num_p = math.ceil(demand_rps / engine_rps)
next_num_p = max(next_num_p, self.config.min_endpoint)
logger.info(
f"Prefill: {demand_rps:.2f}(demand rps) / "
f"{engine_rps:.2f}(engine rps) = {next_num_p}(num_p), "
f"est_ttft={actual_ttft_ms:.1f}ms"
)
return next_num_p
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_num_p.set(desired_replicas)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pure discrete-event state machine for planner scaling decisions.
``PlannerStateMachine`` receives events (``ScheduledTick`` + ``TickInput``),
updates internal state (regression models, load predictors, worker inventory),
and returns effects (``PlannerEffects``: optional scaling decision + next tick).
This module contains **zero I/O** -- no runtime, connector, subscriber, asyncio,
or Prometheus dependencies. All external interaction is done by the adapter
layer (``NativePlannerBase`` and its subclasses) which feeds data in and
applies decisions out.
Load-based scaling logic lives in ``load_scaling.py``.
Throughput-based scaling logic lives in ``throughput_scaling.py``.
"""
from __future__ import annotations
import logging
import math
from typing import TYPE_CHECKING, Optional
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.load.predictors import LOAD_PREDICTORS
from dynamo.planner.core.load_scaling import LoadScalingMixin
from dynamo.planner.core.perf_model import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.core.throughput_scaling import ThroughputScalingMixin
from dynamo.planner.core.types import (
FpmObservations,
PlannerEffects,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
"""Discrete-event state machine for all planner modes.
Owns regression models, load predictors, throughput lower bounds,
and all scaling decision logic. Receives events, returns effects.
Has no runtime dependencies.
"""
def __init__(
self,
config: PlannerConfig,
capabilities: Optional[WorkerCapabilities] = None,
) -> None:
self._config = config
self._capabilities = capabilities or WorkerCapabilities()
self._is_agg = config.mode == "agg"
self._has_prefill = config.mode in ("disagg", "prefill")
self._has_decode = config.mode in ("disagg", "decode", "agg")
if self._is_agg:
self._agg_regression = AggRegressionModel(
max_num_fpm_samples=config.max_num_fpm_samples,
min_observations=config.load_min_observations,
bucket_count=config.fpm_sample_bucket_size,
)
else:
if self._has_prefill:
self._prefill_regression = PrefillRegressionModel(
max_num_fpm_samples=config.max_num_fpm_samples,
min_observations=config.load_min_observations,
bucket_count=config.fpm_sample_bucket_size,
)
if self._has_decode:
self._decode_regression = DecodeRegressionModel(
max_num_fpm_samples=config.max_num_fpm_samples,
min_observations=config.load_min_observations,
bucket_count=config.fpm_sample_bucket_size,
)
predictor_cls = LOAD_PREDICTORS[config.load_predictor]
self._num_req_predictor = predictor_cls(config)
self._isl_predictor = predictor_cls(config)
self._osl_predictor = predictor_cls(config)
self._num_p_workers: int = 0
self._num_d_workers: int = 0
self._expected_num_p: Optional[int] = None
self._expected_num_d: Optional[int] = None
self._throughput_lower_bound_p: int = 1
self._throughput_lower_bound_d: int = 1
self._next_load_s: float = float("inf")
self._next_throughput_s: float = float("inf")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def initial_tick(self, start_s: float) -> ScheduledTick:
self._next_load_s = start_s + self._config.load_adjustment_interval
if self._config.enable_throughput_scaling:
self._next_throughput_s = (
start_s + self._config.throughput_adjustment_interval
)
return self._next_scheduled_tick()
def load_benchmark_fpms(
self,
prefill_fpms: Optional[list[ForwardPassMetrics]] = None,
decode_fpms: Optional[list[ForwardPassMetrics]] = None,
agg_fpms: Optional[list[ForwardPassMetrics]] = None,
) -> None:
if agg_fpms and self._is_agg:
self._agg_regression.load_benchmark_fpms(agg_fpms)
logger.info(f"Bootstrapped agg regression with {len(agg_fpms)} FPMs")
if prefill_fpms and self._has_prefill and not self._is_agg:
self._prefill_regression.load_benchmark_fpms(prefill_fpms)
logger.info(
f"Bootstrapped prefill regression with {len(prefill_fpms)} FPMs"
)
if decode_fpms and self._has_decode and not self._is_agg:
self._decode_regression.load_benchmark_fpms(decode_fpms)
logger.info(f"Bootstrapped decode regression with {len(decode_fpms)} FPMs")
def warm_load_predictors(self, observations: list[TrafficObservation]) -> None:
for obs in observations:
self._num_req_predictor.add_data_point(obs.num_req)
self._isl_predictor.add_data_point(obs.isl)
self._osl_predictor.add_data_point(obs.osl)
logger.info(f"Warmed load predictors with {len(observations)} intervals")
for p in (self._num_req_predictor, self._isl_predictor, self._osl_predictor):
if hasattr(p, "reset_idle_skip"):
p.reset_idle_skip()
def on_tick(self, tick: ScheduledTick, tick_input: TickInput) -> PlannerEffects:
effects = PlannerEffects()
if tick_input.worker_counts is not None:
self._update_inventory(tick_input.worker_counts)
if tick.run_load_scaling:
if tick_input.fpm_observations is not None:
self._observe_fpm(tick_input.fpm_observations)
load_decision = self._advance_load(tick_input.fpm_observations)
if load_decision is not None:
effects.scale_to = load_decision
self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval
if tick.run_throughput_scaling:
if tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
throughput_decision = self._advance_throughput(tick_input.traffic)
if throughput_decision is not None:
if effects.scale_to is None:
effects.scale_to = throughput_decision
self._next_throughput_s = (
tick_input.now_s + self._config.throughput_adjustment_interval
)
effects.next_tick = self._next_scheduled_tick()
return effects
# ------------------------------------------------------------------
# Tick scheduling
# ------------------------------------------------------------------
_MERGE_TOLERANCE_S = 0.5
def _next_scheduled_tick(self) -> ScheduledTick:
"""Build the single next tick, merging cadences if they coincide."""
at_s = min(self._next_load_s, self._next_throughput_s)
is_load = self._next_load_s <= at_s + self._MERGE_TOLERANCE_S
is_throughput = self._next_throughput_s <= at_s + self._MERGE_TOLERANCE_S
return ScheduledTick(
at_s=at_s,
run_load_scaling=is_load,
run_throughput_scaling=is_throughput,
need_worker_states=True,
need_worker_fpm=is_load,
need_traffic_metrics=is_throughput,
traffic_metrics_duration_s=(
self._config.throughput_adjustment_interval if is_throughput else 0.0
),
)
# ------------------------------------------------------------------
# Inventory
# ------------------------------------------------------------------
def _update_inventory(self, counts: WorkerCounts) -> None:
if counts.ready_num_prefill is not None:
self._num_p_workers = counts.ready_num_prefill
if counts.ready_num_decode is not None:
self._num_d_workers = counts.ready_num_decode
self._expected_num_p = counts.expected_num_prefill
self._expected_num_d = counts.expected_num_decode
def _scaling_in_progress(self, component: str) -> bool:
if component == "prefill":
return (
self._expected_num_p is not None
and self._expected_num_p != self._num_p_workers
)
return (
self._expected_num_d is not None
and self._expected_num_d != self._num_d_workers
)
# ------------------------------------------------------------------
# FPM / traffic observation
# ------------------------------------------------------------------
def _observe_fpm(self, obs: FpmObservations) -> None:
if self._is_agg:
if obs.decode:
for fpm in obs.decode.values():
self._agg_regression.add_observation(fpm)
logger.info(f"FPM load stats: {len(obs.decode)} agg engines observed")
return
if obs.prefill and self._has_prefill:
for fpm in obs.prefill.values():
self._prefill_regression.add_observation(fpm)
logger.info(f"FPM load stats: {len(obs.prefill)} prefill engines observed")
if obs.decode and self._has_decode:
for fpm in obs.decode.values():
self._decode_regression.add_observation(fpm)
logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed")
def _observe_traffic(self, traffic: TrafficObservation) -> None:
self._num_req_predictor.add_data_point(traffic.num_req)
self._isl_predictor.add_data_point(traffic.isl)
self._osl_predictor.add_data_point(traffic.osl)
# ------------------------------------------------------------------
# Budget
# ------------------------------------------------------------------
def _apply_single_budget(self, desired: int, component: str) -> int:
caps = (
self._capabilities.prefill
if component == "prefill"
else self._capabilities.decode
)
gpu = caps.num_gpu if caps else None
if gpu is None:
return desired
return self._budget_clamp(max(desired, self._config.min_endpoint), gpu)
def _apply_global_budget(self, num_p: int, num_d: int) -> tuple[int, int]:
budget = self._config.max_gpu_budget
p_gpu = (
self._capabilities.prefill.num_gpu if self._capabilities.prefill else None
)
d_gpu = self._capabilities.decode.num_gpu if self._capabilities.decode else None
if budget < 0 or p_gpu is None or d_gpu is None:
return num_p, num_d
total = num_p * p_gpu + num_d * d_gpu
if total <= budget:
return num_p, num_d
min_req = self._config.min_endpoint * p_gpu + self._config.min_endpoint * d_gpu
if budget < min_req:
logger.warning(
f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
)
return 0, 0
scale = budget / total
max_p = math.floor((budget - self._config.min_endpoint * d_gpu) / p_gpu)
num_p = max(self._config.min_endpoint, min(max_p, math.floor(num_p * scale)))
remaining = budget - num_p * p_gpu
num_d = max(self._config.min_endpoint, math.floor(remaining / d_gpu))
logger.warning(f"GPUs ({total}) > budget ({budget}), -> {num_p}P + {num_d}D")
return num_p, num_d
def _budget_clamp(self, desired: int, engine_gpu: int) -> int:
budget = self._config.max_gpu_budget
if budget < 0:
return desired
total = desired * engine_gpu
if total <= budget:
return desired
min_req = self._config.min_endpoint * engine_gpu
if budget < min_req:
logger.warning(
f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
)
return 0
result = max(self._config.min_endpoint, math.floor(budget / engine_gpu))
logger.warning(f"GPUs ({total}) > budget ({budget}), -> {result} replicas")
return result
# ------------------------------------------------------------------
# FPM / worker count reconciliation
# ------------------------------------------------------------------
@staticmethod
def _reconcile_fpm_worker_count(
fpm_stats: dict[tuple[str, int], ForwardPassMetrics], dgd_count: int, label: str
) -> bool:
workers_to_dp: dict[str, set[int]] = {}
for wid, dp in fpm_stats:
workers_to_dp.setdefault(wid, set()).add(dp)
if len(workers_to_dp) != dgd_count:
logger.warning(
f"Worker count mismatch: DGD={dgd_count}, FPM={len(workers_to_dp)} for {label}"
)
return False
dp_sizes = {len(dps) for dps in workers_to_dp.values()}
if len(dp_sizes) > 1:
logger.warning(f"Inconsistent DP ranks for {label}: {dict(workers_to_dp)}")
return False
dp_size = dp_sizes.pop() if dp_sizes else 1
if len(fpm_stats) != dgd_count * dp_size:
logger.warning(
f"Incomplete FPM coverage for {label}: expected {dgd_count}x{dp_size}, got {len(fpm_stats)}"
)
return False
return True
# ------------------------------------------------------------------
# Accessors
# ------------------------------------------------------------------
@property
def prefill_regression(self) -> PrefillRegressionModel:
if not self._has_prefill:
raise AttributeError(f"No prefill regression in mode={self._config.mode}")
return self._prefill_regression
@property
def decode_regression(self) -> DecodeRegressionModel:
if not self._has_decode or self._is_agg:
raise AttributeError(f"No decode regression in mode={self._config.mode}")
return self._decode_regression
@property
def agg_regression(self) -> AggRegressionModel:
if not self._is_agg:
raise AttributeError(f"No agg regression in mode={self._config.mode}")
return self._agg_regression
@property
def regression(self) -> AggRegressionModel:
return self.agg_regression
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Throughput-based scaling logic (Prometheus traffic-driven, predictive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from __future__ import annotations
import logging
import math
from typing import Optional
from dynamo.planner.core.types import ScalingDecision, TrafficObservation
logger = logging.getLogger(__name__)
class ThroughputScalingMixin:
"""Traffic-driven throughput-based scaling decisions."""
def _advance_throughput(
self, traffic: TrafficObservation
) -> Optional[ScalingDecision]:
if not self._config.enable_throughput_scaling:
return None
next_num_req, next_isl, next_osl = self._predict_load()
if next_num_req is None or next_isl is None or next_osl is None:
return None
if traffic.duration_s <= 0:
logger.warning("Traffic observation has non-positive duration, skipping")
return None
demand_rps = next_num_req / traffic.duration_s
mode = self._config.mode
if mode == "agg":
return self._throughput_agg(demand_rps, next_isl, next_osl)
if mode == "disagg":
return self._throughput_disagg(demand_rps, next_isl, next_osl)
return self._throughput_single(demand_rps, next_isl, next_osl, mode)
def _predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try:
nr = self._num_req_predictor.predict_next()
isl = self._isl_predictor.predict_next()
osl = self._osl_predictor.predict_next()
logger.info(
f"Predicted load: num_req={nr:.2f}, isl={isl:.2f}, osl={osl:.2f}"
)
return nr, isl, osl
except Exception as e:
logger.error(f"Failed to predict load: {e}")
return None, None, None
def _throughput_single(
self, demand_rps: float, isl: float, osl: float, component: str
) -> Optional[ScalingDecision]:
desired = (
self._compute_prefill_replicas(demand_rps, isl, osl)
if component == "prefill"
else self._compute_decode_replicas(demand_rps, isl, osl)
)
if desired is None:
return None
if self._config.enable_load_scaling:
if component == "prefill":
self._throughput_lower_bound_p = desired
else:
self._throughput_lower_bound_d = desired
logger.info(f"Throughput lower bound set to {desired} for {component}")
return None
desired = self._apply_single_budget(desired, component)
return (
ScalingDecision(num_prefill=desired)
if component == "prefill"
else ScalingDecision(num_decode=desired)
)
def _throughput_disagg(
self, demand_rps: float, isl: float, osl: float
) -> Optional[ScalingDecision]:
num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
num_d = self._compute_decode_replicas(demand_rps, isl, osl)
if num_p is None or num_d is None:
return None
if self._config.enable_load_scaling:
self._throughput_lower_bound_p = num_p
self._throughput_lower_bound_d = num_d
logger.info(f"Throughput lower bounds set: prefill={num_p}, decode={num_d}")
return None
num_p, num_d = self._apply_global_budget(num_p, num_d)
return ScalingDecision(num_prefill=num_p, num_decode=num_d)
def _throughput_agg(
self, demand_rps: float, isl: float, osl: float
) -> Optional[ScalingDecision]:
d_caps = self._capabilities.decode
max_tokens = d_caps.max_num_batched_tokens if d_caps else None
if not max_tokens or max_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available, skipping agg throughput"
)
return None
(
engine_rps,
actual_ttft,
actual_itl,
) = self._agg_regression.find_best_engine_agg_rps(
isl=isl,
osl=osl,
max_num_batched_tokens=max_tokens,
ttft_sla=self._config.ttft,
itl_sla=self._config.itl,
)
if engine_rps <= 0:
logger.warning("Agg perf model not ready, skipping throughput scaling")
return None
if actual_ttft > self._config.ttft or actual_itl > self._config.itl:
logger.warning(
f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms, ITL={actual_itl:.1f}ms"
)
desired = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
logger.info(
f"Agg: {demand_rps:.2f} rps / {engine_rps:.2f} engine_rps = {desired} replicas"
)
if self._config.enable_load_scaling:
self._throughput_lower_bound_d = desired
logger.info(f"Agg throughput lower bound set to {desired}")
return None
desired = self._apply_single_budget(desired, "decode")
return ScalingDecision(num_decode=desired)
def _compute_prefill_replicas(
self, demand_rps: float, isl: float, osl: float
) -> Optional[int]:
engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
ttft_sla=self._config.ttft, isl=isl
)
if engine_rps <= 0:
logger.warning("Prefill perf model not ready, skipping throughput scaling")
return None
if ttft_ms > self._config.ttft:
logger.warning(
f"Prefill TTFT SLA not met: {ttft_ms:.1f}ms > {self._config.ttft:.1f}ms"
)
result = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
logger.info(
f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_ttft={ttft_ms:.1f}ms"
)
return result
def _compute_decode_replicas(
self, demand_rps: float, isl: float, osl: float
) -> Optional[int]:
engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps(
itl=self._config.itl,
context_length=isl + osl / 2,
osl=osl,
)
if engine_rps <= 0:
logger.warning("Decode perf model not ready, skipping throughput scaling")
return None
if itl_ms > self._config.itl:
logger.warning(
f"Decode ITL SLA not met: {itl_ms:.1f}ms > {self._config.itl:.1f}ms"
)
result = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
logger.info(
f"Decode: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_itl={itl_ms:.1f}ms"
)
return result
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Explicit-input types for the planner core.
These types form the boundary between the planner core (pure decision logic)
and any adapter (native runtime, replay harness, tests). The core receives
``TickInput`` and returns ``PlannerEffects``; the adapter fills the input
based on the previous tick's ``ScheduledTick`` requirements.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
@dataclass
class ScheduledTick:
"""Declares when the core next needs to be called, what data it needs,
and what decisions to make.
All times are absolute seconds (wall clock for native adapter,
simulated clock for replay).
"""
at_s: float
# What decisions the core will make on this tick
run_load_scaling: bool = False
run_throughput_scaling: bool = False
# What data the adapter should collect before calling on_tick
need_traffic_metrics: bool = False
traffic_metrics_duration_s: float = 0.0
need_worker_states: bool = False
need_worker_fpm: bool = False
@dataclass
class TrafficObservation:
"""Aggregated traffic metrics over an observation window."""
duration_s: float
num_req: float
isl: float
osl: float
@dataclass
class WorkerCounts:
"""Current worker inventory as reported by the adapter."""
ready_num_prefill: Optional[int] = None
ready_num_decode: Optional[int] = None
expected_num_prefill: Optional[int] = None
expected_num_decode: Optional[int] = None
@dataclass
class FpmObservations:
"""Per-engine ForwardPassMetrics keyed by (worker_id, dp_rank)."""
prefill: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
decode: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
@dataclass
class TickInput:
"""What the adapter provides to the core on each tick.
Fields are filled according to the previous ``ScheduledTick``'s
declared requirements.
"""
now_s: float
traffic: Optional[TrafficObservation] = None
worker_counts: Optional[WorkerCounts] = None
fpm_observations: Optional[FpmObservations] = None
@dataclass
class ScalingDecision:
"""Desired replica counts. ``None`` means the core has no opinion
on that component (e.g. prefill-only planner leaves decode as None).
"""
num_prefill: Optional[int] = None
num_decode: Optional[int] = None
@dataclass
class PlannerEffects:
"""What the core returns after processing a tick."""
scale_to: Optional[ScalingDecision] = None
next_tick: Optional[ScheduledTick] = None
@dataclass
class EngineCapabilities:
"""Static capabilities for a single engine stage (prefill or decode)."""
num_gpu: Optional[int] = None
max_num_batched_tokens: Optional[int] = None
max_num_seqs: Optional[int] = None
context_length: Optional[int] = None
@dataclass
class WorkerCapabilities:
"""Static per-engine capabilities discovered at startup from MDC.
Provided once when constructing the planner core. In native mode
these come from ``WorkerInfo`` (resolved via MDC / DGD); in replay
they come from the simulated engine args.
For agg mode, only ``decode`` is populated (single engine type).
"""
prefill: Optional[EngineCapabilities] = None
decode: Optional[EngineCapabilities] = None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import Mock, patch
"""Regression model unit tests.
These test the perf_model classes directly (PrefillRegressionModel,
DecodeRegressionModel, AggRegressionModel) without any planner adapter.
FPM-driven scaling integration tests live in test_state_machine.py.
"""
import pytest
......@@ -15,18 +20,12 @@ from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
encode,
)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.perf_model import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.worker_info import WorkerInfo
pytestmark = [
pytest.mark.gpu_0,
......@@ -149,7 +148,6 @@ class TestPrefillRegressionModel:
ttft_sla=2000.0, isl=1000.0
)
assert rps > 0
# wall_time ~1.002s for 1000 tokens -> rps ~ 1/1.002 ~ 0.998
assert 0.5 < rps < 2.0
assert actual_ttft_ms > 0
assert 1000 < actual_ttft_ms < 2000
......@@ -190,7 +188,6 @@ class TestPrefillRegressionModel:
class TestBucketedRetirement:
def test_total_capped_at_max(self):
"""Total observations never exceed max_num_fpm_samples."""
model = PrefillRegressionModel(
max_num_fpm_samples=10, min_observations=3, bucket_count=4
)
......@@ -204,12 +201,9 @@ class TestBucketedRetirement:
assert model.num_observations == 10
def test_most_populated_bucket_loses_oldest(self):
"""When evicting, the oldest entry from the most-populated bucket is removed."""
model = PrefillRegressionModel(
max_num_fpm_samples=6, min_observations=1, bucket_count=4
)
# 3 observations at low tokens (bucket 0 area)
for i in range(3):
fpm = _make_fpm(
sum_prefill_tokens=10 + i,
......@@ -217,8 +211,6 @@ class TestBucketedRetirement:
wall_time=0.001 * (10 + i),
)
model.add_observation(fpm)
# 3 observations at high tokens (different bucket)
for i in range(3):
fpm = _make_fpm(
sum_prefill_tokens=1000 + i * 100,
......@@ -226,51 +218,31 @@ class TestBucketedRetirement:
wall_time=0.001 * (1000 + i * 100),
)
model.add_observation(fpm)
assert model.num_observations == 6
# One more at low tokens; total would exceed 6 so most-populated
# bucket loses its oldest entry.
fpm = _make_fpm(
sum_prefill_tokens=15,
num_prefill_requests=1,
wall_time=0.015,
)
fpm = _make_fpm(sum_prefill_tokens=15, num_prefill_requests=1, wall_time=0.015)
model.add_observation(fpm)
assert model.num_observations == 6
def test_uniform_distribution_preserved(self):
"""Bucketed eviction keeps observations across operating points."""
model = DecodeRegressionModel(
max_num_fpm_samples=10, min_observations=3, bucket_count=16
)
# Many observations at a single operating point
for _ in range(15):
fpm = _make_fpm(
num_decode_requests=32,
sum_decode_kv_tokens=32000,
wall_time=0.01,
num_decode_requests=32, sum_decode_kv_tokens=32000, wall_time=0.01
)
model.add_observation(fpm)
assert model.num_observations == 10
# Add a different operating point; the concentrated bucket loses one
fpm = _make_fpm(
num_decode_requests=4,
sum_decode_kv_tokens=4000,
wall_time=0.005,
num_decode_requests=4, sum_decode_kv_tokens=4000, wall_time=0.005
)
model.add_observation(fpm)
assert model.num_observations == 10
def test_2d_bucketed_retirement(self):
"""2D models retire from the most-populated grid cell."""
model = AggRegressionModel(
max_num_fpm_samples=8, min_observations=1, bucket_count=16
)
# Fill with varied data
for p, d in [(100, 500), (200, 1000), (300, 1500), (400, 2000)]:
fpm = _make_fpm(
sum_prefill_tokens=p,
......@@ -280,8 +252,6 @@ class TestBucketedRetirement:
wall_time=0.001 * p + 0.0001 * d,
)
model.add_observation(fpm)
# Concentrate 4 more in one region
for _ in range(4):
fpm = _make_fpm(
sum_prefill_tokens=100,
......@@ -291,10 +261,7 @@ class TestBucketedRetirement:
wall_time=0.15,
)
model.add_observation(fpm)
assert model.num_observations == 8
# Overflow triggers retirement from the concentrated cell
fpm = _make_fpm(
sum_prefill_tokens=350,
num_prefill_requests=1,
......@@ -310,15 +277,8 @@ class TestBucketedRetirement:
class TestDecodeRegressionModel:
def _train_2d(self, model: DecodeRegressionModel) -> None:
"""Populate with 2D data: wall_time = f(num_decode_requests, sum_decode_kv_tokens)."""
for n_req, kv in [
(5, 1000),
(10, 2000),
(15, 3000),
(20, 4000),
(25, 5000),
]:
def _train_2d(self, model):
for n_req, kv in [(5, 1000), (10, 2000), (15, 3000), (20, 4000), (25, 5000)]:
fpm = _make_fpm(
sum_decode_kv_tokens=kv,
num_decode_requests=n_req,
......@@ -346,11 +306,9 @@ class TestDecodeRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_2d(model)
assert model.has_sufficient_data()
est = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert est is not None
assert est > 0
assert est is not None and est > 0
def test_avg_decode_length_tracking(self):
model = DecodeRegressionModel(
......@@ -365,8 +323,7 @@ class TestDecodeRegressionModel:
model.add_observation(fpm)
assert abs(model.avg_decode_length - 200.0) < 1.0
def _train_thpt_model(self, model: DecodeRegressionModel) -> None:
"""Populate with 2D data at decode-realistic wall-time scale."""
def _train_thpt_model(self, model):
for n_req, kv in [
(5, 5000),
(10, 10000),
......@@ -386,13 +343,10 @@ class TestDecodeRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_thpt_model(model)
rps, actual_itl = model.find_best_engine_decode_rps(
itl=50.0, context_length=1000.0, osl=150.0
)
assert rps > 0
assert actual_itl > 0
assert actual_itl <= 50.0
assert rps > 0 and actual_itl > 0 and actual_itl <= 50.0
def test_find_best_engine_decode_rps_zero_context(self):
model = DecodeRegressionModel(
......@@ -402,8 +356,7 @@ class TestDecodeRegressionModel:
rps, itl_ms = model.find_best_engine_decode_rps(
itl=50.0, context_length=0.0, osl=150.0
)
assert rps == 0.0
assert itl_ms == 0.0
assert rps == 0.0 and itl_ms == 0.0
def test_load_benchmark_fpms(self):
model = DecodeRegressionModel(
......@@ -418,15 +371,14 @@ class TestDecodeRegressionModel:
for n in [5, 10, 15, 20, 25]
]
model.load_benchmark_fpms(fpms)
assert model.num_observations == 5
assert model.has_sufficient_data()
assert model.num_observations == 5 and model.has_sufficient_data()
# ── AggRegressionModel tests ─────────────────────────────────────────
class TestAggRegressionModel:
def _train_agg(self, model: AggRegressionModel) -> None:
def _train_agg(self, model):
for p, d in [(100, 1000), (200, 2000), (300, 3000), (400, 4000), (500, 5000)]:
fpm = _make_fpm(
sum_prefill_tokens=p,
......@@ -458,27 +410,19 @@ class TestAggRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
assert model.has_sufficient_data()
ttft = model.estimate_next_ttft(
queued_prefill_tokens=0,
max_num_batched_tokens=2048,
current_decode_kv=3000,
queued_prefill_tokens=0, max_num_batched_tokens=2048, current_decode_kv=3000
)
assert ttft is not None
assert ttft > 0
assert ttft is not None and ttft > 0
itl = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert itl is not None
assert itl > 0
assert itl is not None and itl > 0
def test_find_best_engine_agg_rps(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
thpt, actual_ttft, actual_itl = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
......@@ -486,10 +430,7 @@ class TestAggRegressionModel:
ttft_sla=500.0,
itl_sla=50.0,
)
assert isinstance(thpt, float)
assert thpt > 0
assert actual_ttft >= 0
assert actual_itl >= 0
assert thpt > 0 and actual_ttft >= 0 and actual_itl >= 0
def test_find_best_engine_agg_rps_insufficient_data(self):
model = AggRegressionModel(
......@@ -503,221 +444,3 @@ class TestAggRegressionModel:
itl_sla=50.0,
)
assert thpt == 0.0
# ── Planner integration tests (with mocked FPM subscriber) ──────────
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
with patch("dynamo.planner.monitoring.planner_metrics.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
yield
def _build_load_config(**overrides) -> PlannerConfig:
defaults = dict(
throughput_adjustment_interval=60,
prefill_engine_num_gpu=1,
decode_engine_num_gpu=1,
min_endpoint=1,
max_gpu_budget=-1,
ttft=500.0,
itl=50.0,
backend="vllm",
no_operation=True,
metric_pulling_prometheus_endpoint="http://localhost:9090",
metric_reporting_prometheus_port=0,
load_predictor="constant",
profile_results_dir=os.path.join(
os.path.dirname(__file__),
"..",
"data",
"profiling_results",
"H200_TP1P_TP1D",
),
environment="kubernetes",
namespace="test-namespace",
mode="disagg",
enable_load_scaling=True,
enable_throughput_scaling=True,
load_adjustment_interval=5,
max_num_fpm_samples=50,
fpm_sample_bucket_size=16,
load_scaling_down_sensitivity=80,
load_metric_samples=10,
load_min_observations=5,
)
defaults.update(overrides)
return PlannerConfig.model_construct(**defaults)
def _mock_fpm_subscriber(fpm_stats: dict[tuple[str, int], ForwardPassMetrics]):
"""Create a mock FPM subscriber that returns encoded FPM stats."""
mock = Mock()
encoded = {k: encode(v) for k, v in fpm_stats.items()}
mock.get_recent_stats.return_value = encoded
return mock
class TestPrefillFpmScaling:
def test_scale_up_all_engines_above_sla(self):
"""All engines have high queued prefill -> estimated TTFT > SLA -> scale up."""
config = _build_load_config(ttft=5.0) # 5ms SLA (easy to exceed)
shared_state = PlannerSharedState()
shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
for tokens in range(200, 1200, 100):
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
planner.ttft_regression.add_observation(fpm)
stats = {
("w1", 0): _make_fpm(
worker_id="w1",
queued_prefill_tokens=10000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
),
("w2", 0): _make_fpm(
worker_id="w2",
queued_prefill_tokens=8000,
sum_prefill_tokens=600,
num_prefill_requests=1,
wall_time=0.6,
),
}
planner.fpm_subscriber = _mock_fpm_subscriber(stats)
result = planner.load_plan_adjustment()
assert result == 3
def test_scale_down_all_engines_below_sla(self):
"""All engines have low queued prefill -> estimated TTFT < SLA * sensitivity."""
config = _build_load_config(ttft=500.0, load_scaling_down_sensitivity=100)
shared_state = PlannerSharedState()
shared_state.num_p_workers = 3
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
for tokens in range(100, 600, 50):
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
planner.ttft_regression.add_observation(fpm)
stats = {
(f"w{i}", 0): _make_fpm(
worker_id=f"w{i}",
queued_prefill_tokens=0,
sum_prefill_tokens=100,
num_prefill_requests=1,
wall_time=0.1,
)
for i in range(3)
}
planner.fpm_subscriber = _mock_fpm_subscriber(stats)
result = planner.load_plan_adjustment()
assert result == 2
def test_cold_start_returns_none(self):
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
for tokens in [100, 200]:
fpm = _make_fpm(sum_prefill_tokens=tokens, wall_time=0.01)
planner.ttft_regression.add_observation(fpm)
stats = {("w1", 0): _make_fpm(queued_prefill_tokens=5000, wall_time=0.5)}
planner.fpm_subscriber = _mock_fpm_subscriber(stats)
result = planner.load_plan_adjustment()
assert result is None
class TestDecodeFpmScaling:
def test_scale_up_all_engines_above_sla(self):
"""All engines have high decode load -> estimated ITL > SLA -> scale up."""
config = _build_load_config(itl=5.0) # 5ms SLA
shared_state = PlannerSharedState()
shared_state.num_d_workers = 2
planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
# 2D regression: vary both num_decode_requests and sum_decode_kv_tokens
for n_req, kv in [
(5, 1000),
(10, 2000),
(15, 3000),
(20, 4000),
(25, 5000),
]:
fpm = _make_fpm(
sum_decode_kv_tokens=kv,
num_decode_requests=n_req,
wall_time=0.0001 * kv + 0.0005 * n_req + 0.001,
)
planner.itl_regression.add_observation(fpm)
stats = {
("w1", 0): _make_fpm(
worker_id="w1",
sum_decode_kv_tokens=5000,
queued_decode_kv_tokens=3000,
num_decode_requests=20,
wall_time=0.6,
),
("w2", 0): _make_fpm(
worker_id="w2",
sum_decode_kv_tokens=4500,
queued_decode_kv_tokens=2500,
num_decode_requests=18,
wall_time=0.55,
),
}
planner.fpm_subscriber = _mock_fpm_subscriber(stats)
result = planner.load_plan_adjustment()
assert result == 3
def test_cold_start_returns_none(self):
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_d_workers = 2
planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
fpm = _make_fpm(
sum_decode_kv_tokens=1000, num_decode_requests=5, wall_time=0.01
)
planner.itl_regression.add_observation(fpm)
stats = {
("w1", 0): _make_fpm(
sum_decode_kv_tokens=5000, num_decode_requests=10, wall_time=0.5
)
}
planner.fpm_subscriber = _mock_fpm_subscriber(stats)
result = planner.load_plan_adjustment()
assert result is None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for SLA planner replica calculation logic.
These tests focus specifically on the replica calculation formulas without
testing load prediction or regression internals.
"""
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.budget import _apply_global_gpu_budget
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.traffic_metrics import Metrics
from dynamo.planner.monitoring.worker_info import WorkerInfo
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.unit,
pytest.mark.planner,
]
class PlannerHarness:
def __init__(self, prefill_planner, decode_planner, shared_state):
self.prefill_planner = prefill_planner
self.decode_planner = decode_planner
self.shared_state = shared_state
self.last_target_replicas = []
async def make_adjustments(self):
if not self.shared_state.last_metrics.is_valid():
return
num_p, num_d, is_stable = await self.prefill_planner.get_workers_info()
self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d
next_num_p = self.prefill_planner.plan_adjustment()
next_num_d = self.decode_planner.plan_adjustment()
if next_num_p is None or next_num_d is None:
return
next_num_p, next_num_d = _apply_global_gpu_budget(
next_num_p, next_num_d, self.prefill_planner.config
)
self.prefill_planner.update_predicted_replicas_metric(next_num_p)
self.decode_planner.update_predicted_replicas_metric(next_num_d)
target_replicas = [
{
"sub_component_type": "prefill",
"component_name": self.prefill_planner.prefill_worker_info.k8s_name,
"desired_replicas": next_num_p,
},
{
"sub_component_type": "decode",
"component_name": self.prefill_planner.decode_worker_info.k8s_name,
"desired_replicas": next_num_d,
},
]
self.last_target_replicas = target_replicas
if not self.prefill_planner.config.no_operation:
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=False
)
def __getattr__(self, name):
shared_attrs = {
"num_req_predictor",
"isl_predictor",
"osl_predictor",
"connector",
"prometheus_traffic_client",
"config",
}
prefill_attrs = {
"ttft_regression",
"prefill_worker_info",
}
decode_attrs = {
"itl_regression",
"decode_worker_info",
}
if name == "last_metrics":
return self.shared_state.last_metrics
if name == "get_workers_info":
return self.prefill_planner.get_workers_info
if name in shared_attrs:
return getattr(self.prefill_planner, name)
if name in prefill_attrs:
return getattr(self.prefill_planner, name)
if name in decode_attrs:
return getattr(self.decode_planner, name)
raise AttributeError(name)
def __setattr__(self, name, value):
if name in {"prefill_planner", "decode_planner", "shared_state"}:
return super().__setattr__(name, value)
shared_attrs = {
"num_req_predictor",
"isl_predictor",
"osl_predictor",
"connector",
"prometheus_traffic_client",
"config",
"get_workers_info",
}
prefill_attrs = {"ttft_regression"}
decode_attrs = {"itl_regression"}
if name == "last_metrics":
self.shared_state.last_metrics = value
return None
if name in shared_attrs:
# Store locally to support patch.object lifecycle (set/del).
object.__setattr__(self, name, value)
setattr(self.prefill_planner, name, value)
setattr(self.decode_planner, name, value)
return None
if name in prefill_attrs:
setattr(self.prefill_planner, name, value)
return None
if name in decode_attrs:
setattr(self.decode_planner, name, value)
return None
return super().__setattr__(name, value)
def _replica_count(target_replicas, component_name, default=1):
for replica in target_replicas:
if replica.get("component_name") == component_name:
return replica.get("desired_replicas", default)
return default
@pytest.fixture
def planner():
"""Set up test environment with mocked dependencies."""
config = PlannerConfig.model_construct(
throughput_adjustment_interval=60,
prefill_engine_num_gpu=1,
decode_engine_num_gpu=1,
min_endpoint=1,
max_gpu_budget=10,
ttft=80.0,
itl=10.0,
backend="vllm",
no_operation=True,
metric_pulling_prometheus_endpoint="http://localhost:9090",
metric_reporting_prometheus_port=0,
load_predictor="constant",
profile_results_dir=os.path.join(
os.path.dirname(__file__),
"..",
"data",
"profiling_results",
"H200_TP1P_TP1D",
),
environment="kubernetes",
namespace="test-namespace",
enable_throughput_scaling=True,
enable_load_scaling=False,
load_predictor_warmup_trace=None,
load_predictor_log1p=False,
max_num_fpm_samples=50,
fpm_sample_bucket_size=16,
load_min_observations=5,
)
mock_runtime = Mock()
with patch("dynamo.planner.monitoring.planner_metrics.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(
mock_runtime, config, shared_state=shared_state
)
decode_planner = DecodePlanner(mock_runtime, config, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Set up WorkerInfo for both planners
prefill_planner.prefill_worker_info = WorkerInfo(
k8s_name="VllmPrefillWorker",
component_name="prefill",
endpoint="generate",
)
prefill_planner.decode_worker_info = WorkerInfo(
k8s_name="VllmDecodeWorker",
component_name="backend",
endpoint="generate",
)
decode_planner.prefill_worker_info = prefill_planner.prefill_worker_info
decode_planner.decode_worker_info = prefill_planner.decode_worker_info
planner.ttft_regression = Mock()
# Default: 40000 tokens/s at isl=3000 → 40000/3000 rps
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
40000.0 / 3000.0,
75.0,
)
planner.ttft_regression.has_sufficient_data.return_value = True
planner.itl_regression = Mock()
# Default: 10000 tokens/s at osl=150 → 10000/150 rps
planner.itl_regression.find_best_engine_decode_rps.return_value = (
10000.0 / 150.0,
9.5,
)
planner.itl_regression.has_sufficient_data.return_value = True
# Mock the predictors to return fixed values
planner.num_req_predictor = Mock()
planner.isl_predictor = Mock()
planner.osl_predictor = Mock()
# Mock the connector since we're not testing actual scaling
planner.connector = Mock()
# Mock prometheus client
planner.prometheus_traffic_client = Mock()
planner.config = config
yield planner
class TestReplicaCalculation:
"""Test replica calculation formulas in isolation."""
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_prefill_replica_calculation_basic(self, planner):
"""Test basic prefill replica calculation."""
next_num_req = 10
next_isl = 3000
engine_rps = 40000.0 / next_isl
planner.num_req_predictor.predict_next.return_value = next_num_req
planner.isl_predictor.predict_next.return_value = next_isl
planner.osl_predictor.predict_next.return_value = 150
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
engine_rps,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
10000.0 / 150.0,
9.5,
)
# Formula: ceil(num_req / interval / engine_rps)
pred_prefill_demand = (
next_num_req / planner.config.throughput_adjustment_interval
)
expected_prefill_replicas = math.ceil(pred_prefill_demand / engine_rps)
planner.last_metrics = Metrics(
num_req=10, isl=3000, osl=150, ttft=80.0, itl=10.0, request_duration=100.0
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
asyncio.run(planner.make_adjustments())
prefill_component = "VllmPrefillWorker"
calculated_prefill_replicas = _replica_count(
planner.last_target_replicas, prefill_component
)
print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
assert (
max(expected_prefill_replicas, planner.config.min_endpoint)
== calculated_prefill_replicas
)
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_decode_replica_calculation_basic(self, planner):
"""Test basic decode replica calculation."""
next_num_req = 10
next_osl = 150
engine_rps = 10000.0 / next_osl
planner.num_req_predictor.predict_next.return_value = next_num_req
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = next_osl
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
40000.0 / 3000.0,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
engine_rps,
9.5,
)
# Formula: ceil(num_req / interval / engine_rps)
expected_decode_replicas = math.ceil(
next_num_req / planner.config.throughput_adjustment_interval / engine_rps
)
planner.last_metrics = Metrics(
num_req=10, isl=3000, osl=150, ttft=80.0, itl=10.0, request_duration=100.0
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
asyncio.run(planner.make_adjustments())
decode_component = "VllmDecodeWorker"
calculated_decode_replicas = _replica_count(
planner.last_target_replicas, decode_component
)
print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}")
assert (
max(expected_decode_replicas, planner.config.min_endpoint)
== calculated_decode_replicas
)
@pytest.mark.parametrize(
"num_req,decode_rps,expected_p,expected_d",
[
(10, 10000.0 / 150.0, 1, 1), # low_load_10_req_per_second
(
500,
1000.0 / 150.0,
1,
2,
), # high_load_500_req_per_second (lower decode rps)
],
)
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_scaling_scenario_low_to_high_load(
self, planner, num_req, decode_rps, expected_p, expected_d
):
"""Test scaling from low to high load scenarios."""
planner.num_req_predictor.predict_next.return_value = num_req
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
40000.0 / 3000.0,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
decode_rps,
9.5,
)
planner.last_metrics = Metrics(
num_req=num_req,
isl=3000,
osl=150,
ttft=80.0,
itl=10.0,
request_duration=100.0,
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
planner.connector.reset_mock()
asyncio.run(planner.make_adjustments())
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
assert (
prefill_replicas == expected_p
), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}"
assert (
decode_replicas == expected_d
), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}"
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_gpu_budget_constraint(self, planner):
"""Test that GPU budget constraints are properly applied."""
planner.config.max_gpu_budget = 3
planner.num_req_predictor.predict_next.return_value = 50
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
40000.0 / 3000.0,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
10000.0 / 150.0,
9.5,
)
planner.last_metrics = Metrics(
num_req=50, isl=3000, osl=150, ttft=80.0, itl=10.0, request_duration=100.0
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
asyncio.run(planner.make_adjustments())
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = (
prefill_replicas * planner.config.prefill_engine_num_gpu
+ decode_replicas * planner.config.decode_engine_num_gpu
)
print(
f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
)
assert (
total_gpus <= planner.config.max_gpu_budget
), "Total GPU usage exceeds budget"
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_min_endpoint_constraint(self, planner):
"""Test that minimum endpoint constraints are respected."""
planner.config.min_endpoint = 2
planner.num_req_predictor.predict_next.return_value = 1
planner.isl_predictor.predict_next.return_value = 100
planner.osl_predictor.predict_next.return_value = 10
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
40000.0 / 100.0,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
10000.0 / 10.0,
9.5,
)
planner.last_metrics = Metrics(
num_req=1, isl=100, osl=10, ttft=80.0, itl=10.0, request_duration=100.0
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
asyncio.run(planner.make_adjustments())
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
assert (
prefill_replicas >= planner.config.min_endpoint
), "Prefill replicas below minimum"
assert (
decode_replicas >= planner.config.min_endpoint
), "Decode replicas below minimum"
@pytest.mark.nightly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_multi_gpu_engines(self, planner):
"""Test replica calculation with multi-GPU engines."""
planner.config.prefill_engine_num_gpu = 2
planner.config.decode_engine_num_gpu = 4
planner.num_req_predictor.predict_next.return_value = 20
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
# Engine-level request rate (already accounts for multi-GPU)
prefill_engine_rps = 40000.0 / 3000.0
decode_engine_rps = 5000.0 / 150.0
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
prefill_engine_rps,
75.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
decode_engine_rps,
9.5,
)
planner.last_metrics = Metrics(
num_req=20, isl=3000, osl=150, ttft=80.0, itl=10.0, request_duration=100.0
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
# No engine_num_gpu division — regression returns engine-level rps
expected_prefill_replicas = math.ceil(
20 / planner.config.throughput_adjustment_interval / prefill_engine_rps
)
expected_decode_replicas = math.ceil(
20 / planner.config.throughput_adjustment_interval / decode_engine_rps
)
asyncio.run(planner.make_adjustments())
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), "
f"D={decode_replicas} (expected ~{expected_decode_replicas})"
)
assert prefill_replicas == max(
expected_prefill_replicas, planner.config.min_endpoint
)
assert decode_replicas == max(
expected_decode_replicas, planner.config.min_endpoint
)
@pytest.mark.weekly
@pytest.mark.gpu_2
@pytest.mark.performance
def test_complex_gpu_budget_scaling(self, planner):
"""Test complex GPU budget scaling with proportional reduction."""
planner.config.max_gpu_budget = 5
planner.config.prefill_engine_num_gpu = 2
planner.config.decode_engine_num_gpu = 2
planner.config.min_endpoint = 1
planner.num_req_predictor.predict_next.return_value = 100
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
10000.0 / 3000.0,
300.0,
)
planner.itl_regression.find_best_engine_decode_rps.return_value = (
1000.0 / 150.0,
9.5,
)
planner.last_metrics = Metrics(
num_req=100,
isl=3000,
osl=150,
ttft=80.0,
itl=10.0,
request_duration=100.0,
)
async def mock_get_workers_info(*args, **kwargs):
return (1, 1, True)
planner.get_workers_info = mock_get_workers_info
asyncio.run(planner.make_adjustments())
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = (
prefill_replicas * planner.config.prefill_engine_num_gpu
+ decode_replicas * planner.config.decode_engine_num_gpu
)
print(
f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, "
f"Total GPUs={total_gpus}"
)
assert (
total_gpus <= planner.config.max_gpu_budget
), "Total GPU usage should not exceed budget"
assert (
prefill_replicas >= planner.config.min_endpoint
), "Should respect min_endpoint for prefill"
assert (
decode_replicas >= planner.config.min_endpoint
), "Should respect min_endpoint for decode"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import math
import os
from unittest.mock import MagicMock, Mock, patch
import pytest
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.budget import _initialize_gpu_counts
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.errors import DeploymentValidationError
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
PREFILL_ENGINE_RPS = 10.0
DECODE_ENGINE_RPS = 5.0
DECODE_ACTUAL_ITL_MS = 40.0
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
with patch("dynamo.planner.monitoring.planner_metrics.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
yield
def _build_config():
return PlannerConfig.model_construct(
throughput_adjustment_interval=60,
prefill_engine_num_gpu=1,
decode_engine_num_gpu=1,
min_endpoint=1,
max_gpu_budget=-1,
ttft=500.0,
itl=50.0,
backend="vllm",
no_operation=True,
metric_pulling_prometheus_endpoint="http://localhost:9090",
metric_reporting_prometheus_port=0,
load_predictor="constant",
load_predictor_warmup_trace=None,
load_predictor_log1p=False,
profile_results_dir=os.path.join(
os.path.dirname(__file__),
"..",
"data",
"profiling_results",
"H200_TP1P_TP1D",
),
environment="kubernetes",
namespace="test-namespace",
mode="disagg",
enable_throughput_scaling=True,
enable_load_scaling=False,
)
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(config, prometheus_client):
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(None, config, shared_state=shared_state)
decode_planner = DecodePlanner(None, config, shared_state=shared_state)
prefill_planner.prometheus_traffic_client = prometheus_client
decode_planner.prometheus_traffic_client = prometheus_client
prefill_planner.model_name = "test-model"
decode_planner.model_name = "test-model"
prefill_planner.ttft_regression = MagicMock()
prefill_planner.ttft_regression.find_best_engine_prefill_rps.return_value = (
PREFILL_ENGINE_RPS,
75.0,
)
prefill_planner.ttft_regression.has_sufficient_data.return_value = True
decode_planner.itl_regression = MagicMock()
decode_planner.itl_regression.find_best_engine_decode_rps.return_value = (
DECODE_ENGINE_RPS,
DECODE_ACTUAL_ITL_MS,
)
decode_planner.itl_regression.has_sufficient_data.return_value = True
async def mock_get_workers_info(require_prefill=True, require_decode=True):
return (
1 if require_prefill else 0,
1 if require_decode else 0,
True, # is_stable
)
prefill_planner.get_workers_info = mock_get_workers_info
decode_planner.get_workers_info = mock_get_workers_info
return prefill_planner, decode_planner, shared_state
def _expected_prefill(config, prefill_planner, sample):
demand_rps = sample["num_req"] / config.throughput_adjustment_interval
engine_rps, _ = prefill_planner.ttft_regression.find_best_engine_prefill_rps(
ttft_sla=config.ttft, isl=sample["isl"]
)
expected = math.ceil(demand_rps / engine_rps)
return max(expected, config.min_endpoint)
def _expected_decode(config, decode_planner, sample):
demand_rps = sample["num_req"] / config.throughput_adjustment_interval
engine_rps, _ = decode_planner.itl_regression.find_best_engine_decode_rps(
itl=config.itl, context_length=sample["isl"] + sample["osl"] / 2
)
expected = math.ceil(demand_rps / engine_rps)
return max(expected, config.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
asyncio.run(
prefill_planner.observe_traffic_stats(require_prefill=True, require_decode=True)
)
decode_planner.update_predictors_from_metrics(shared_state.last_metrics)
next_num_p = prefill_planner.plan_adjustment()
next_num_d = decode_planner.plan_adjustment()
return next_num_p, next_num_d
def test_disagg_scale_up():
config = _build_config()
samples = [
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(config, client)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert low_p == _expected_prefill(config, prefill_planner, samples[0])
assert low_d == _expected_decode(config, decode_planner, samples[0])
assert high_p == _expected_prefill(config, prefill_planner, samples[1])
assert high_d == _expected_decode(config, decode_planner, samples[1])
assert high_p > low_p
assert high_d > low_d
def test_disagg_scale_down():
config = _build_config()
samples = [
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(config, client)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert high_p == _expected_prefill(config, prefill_planner, samples[0])
assert high_d == _expected_decode(config, decode_planner, samples[0])
assert low_p == _expected_prefill(config, prefill_planner, samples[1])
assert low_d == _expected_decode(config, decode_planner, samples[1])
assert low_p < high_p
assert low_d < high_d
class TestInitializeGpuCounts:
@staticmethod
def _make_config(**overrides):
defaults = dict(prefill_engine_num_gpu=None, decode_engine_num_gpu=None)
defaults.update(overrides)
return PlannerConfig.model_construct(**defaults)
def test_kubernetes_mode_reads_from_dgd(self):
"""Test that GPU counts are read from DGD in Kubernetes mode"""
config = self._make_config()
connector = Mock()
connector.get_gpu_counts = Mock(return_value=(2, 4))
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert config.prefill_engine_num_gpu == 2
assert config.decode_engine_num_gpu == 4
connector.get_gpu_counts.assert_called_once_with(
require_prefill=True, require_decode=True
)
def test_kubernetes_mode_prefill_only(self):
"""Test GPU count initialization for prefill-only mode"""
config = self._make_config()
connector = Mock()
connector.get_gpu_counts = Mock(return_value=(2, 0))
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=False
)
assert config.prefill_engine_num_gpu == 2
assert config.decode_engine_num_gpu == 0
connector.get_gpu_counts.assert_called_once_with(
require_prefill=True, require_decode=False
)
def test_virtual_mode_uses_cli_args(self):
"""Test that GPU counts come from config in virtual mode"""
config = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
connector = Mock(spec=[])
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert config.prefill_engine_num_gpu == 2
assert config.decode_engine_num_gpu == 4
def test_virtual_mode_missing_prefill_raises_error(self):
"""Test that missing prefill GPU config raises error in virtual mode"""
config = self._make_config(decode_engine_num_gpu=4)
connector = Mock(spec=[])
with pytest.raises(DeploymentValidationError) as exc_info:
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert "prefill_engine_num_gpu" in str(exc_info.value)
def test_virtual_mode_missing_decode_raises_error(self):
"""Test that missing decode GPU config raises error in virtual mode"""
config = self._make_config(prefill_engine_num_gpu=2)
connector = Mock(spec=[])
with pytest.raises(DeploymentValidationError) as exc_info:
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert "decode_engine_num_gpu" in str(exc_info.value)
def test_virtual_mode_missing_both_raises_error_with_both_messages(self):
"""Test that missing both GPU configs shows both error messages"""
config = self._make_config()
connector = Mock(spec=[])
with pytest.raises(DeploymentValidationError) as exc_info:
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert len(exc_info.value.errors) == 2
def test_virtual_mode_decode_only_no_prefill_error(self):
"""Test decode-only mode doesn't require prefill GPU config"""
config = self._make_config(decode_engine_num_gpu=4)
connector = Mock(spec=[])
_initialize_gpu_counts(
config, connector, require_prefill=False, require_decode=True
)
assert config.decode_engine_num_gpu == 4
def test_kubernetes_mode_fallback_to_cli_on_dgd_error(self):
"""Test that K8s mode falls back to config when DGD parsing fails"""
config = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
connector = Mock()
connector.get_gpu_counts = Mock(
side_effect=ValueError("No GPU count specified")
)
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert config.prefill_engine_num_gpu == 2
assert config.decode_engine_num_gpu == 4
def test_kubernetes_mode_fallback_missing_cli_flags_raises_error(self):
"""Test that K8s fallback raises error when config also missing"""
config = self._make_config()
connector = Mock()
connector.get_gpu_counts = Mock(
side_effect=ValueError("No GPU count specified")
)
with pytest.raises(DeploymentValidationError) as exc_info:
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert len(exc_info.value.errors) == 2
def test_kubernetes_mode_fallback_partial_cli_flags(self):
"""Test K8s fallback with only one config value provided"""
config = self._make_config(prefill_engine_num_gpu=2)
connector = Mock()
connector.get_gpu_counts = Mock(
side_effect=ValueError("No GPU count specified")
)
with pytest.raises(DeploymentValidationError) as exc_info:
_initialize_gpu_counts(
config, connector, require_prefill=True, require_decode=True
)
assert "decode_engine_num_gpu" in str(exc_info.value)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Core-only planner tests: TickInput -> PlannerEffects, no mocks."""
import pytest
try:
import msgspec # noqa: F401
except ImportError:
pytest.skip("msgspec required for FPM tests", allow_module_level=True)
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
EngineCapabilities,
FpmObservations,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
def _tick_for(tick_input: TickInput) -> ScheduledTick:
"""Build a ScheduledTick matching the data present in a TickInput."""
has_fpm = tick_input.fpm_observations is not None
has_traffic = tick_input.traffic is not None
return ScheduledTick(
at_s=tick_input.now_s,
run_load_scaling=has_fpm,
run_throughput_scaling=has_traffic,
need_worker_states=True,
need_worker_fpm=has_fpm,
need_traffic_metrics=has_traffic,
traffic_metrics_duration_s=tick_input.traffic.duration_s
if has_traffic
else 0.0,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
def _make_fpm(
*,
sum_prefill_tokens: int = 0,
num_prefill_requests: int = 0,
sum_decode_kv_tokens: int = 0,
num_decode_requests: int = 0,
queued_prefill_tokens: int = 0,
queued_decode_kv_tokens: int = 0,
wall_time: float = 0.01,
worker_id: str = "w1",
dp_rank: int = 0,
) -> ForwardPassMetrics:
return ForwardPassMetrics(
worker_id=worker_id,
dp_rank=dp_rank,
wall_time=wall_time,
scheduled_requests=ScheduledRequestMetrics(
sum_prefill_tokens=sum_prefill_tokens,
num_prefill_requests=num_prefill_requests,
sum_decode_kv_tokens=sum_decode_kv_tokens,
num_decode_requests=num_decode_requests,
),
queued_requests=QueuedRequestMetrics(
sum_prefill_tokens=queued_prefill_tokens,
sum_decode_kv_tokens=queued_decode_kv_tokens,
),
)
def _make_config(**overrides) -> PlannerConfig:
defaults = dict(
mode="disagg",
ttft=500.0,
itl=50.0,
min_endpoint=1,
max_gpu_budget=-1,
throughput_adjustment_interval=60,
load_adjustment_interval=5,
load_scaling_down_sensitivity=80,
max_num_fpm_samples=50,
fpm_sample_bucket_size=16,
load_min_observations=5,
enable_load_scaling=True,
enable_throughput_scaling=True,
load_predictor="constant",
no_operation=True,
backend="vllm",
metric_pulling_prometheus_endpoint="http://localhost:9090",
metric_reporting_prometheus_port=0,
)
defaults.update(overrides)
return PlannerConfig.model_construct(**defaults)
def _default_caps() -> WorkerCapabilities:
return WorkerCapabilities(
prefill=EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048),
decode=EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048),
)
def _agg_caps() -> WorkerCapabilities:
return WorkerCapabilities(
decode=EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048),
)
def _agg_config(**overrides) -> PlannerConfig:
return _make_config(mode="agg", **overrides)
def _make_core(config=None, caps=None, **config_overrides) -> PlannerStateMachine:
cfg = config or _make_config(**config_overrides)
return PlannerStateMachine(cfg, caps or _default_caps())
def _make_agg_core(config=None, caps=None, **config_overrides) -> PlannerStateMachine:
cfg = config or _agg_config(**config_overrides)
return PlannerStateMachine(cfg, caps or _agg_caps())
def _train_prefill_regression(core: PlannerStateMachine) -> None:
fpms = [
_make_fpm(
sum_prefill_tokens=t, num_prefill_requests=1, wall_time=0.001 * t + 0.002
)
for t in [500, 1000, 1500, 2000, 2500]
]
core.load_benchmark_fpms(prefill_fpms=fpms)
def _train_decode_regression(core: PlannerStateMachine) -> None:
fpms = [
_make_fpm(
sum_decode_kv_tokens=kv,
num_decode_requests=n,
wall_time=0.00001 * kv + 0.001,
)
for n, kv in [(5, 5000), (10, 10000), (20, 20000), (30, 30000), (40, 40000)]
]
core.load_benchmark_fpms(decode_fpms=fpms)
# ── Initial ticks ─────────────────────────────────────────────────────
class TestInitialTick:
def test_both_enabled_returns_earliest(self):
core = _make_core()
tick = core.initial_tick(start_s=100.0)
# Load interval (5s) < throughput interval (60s), so load tick first
assert tick.at_s == 105.0
assert tick.need_worker_fpm
assert not tick.need_traffic_metrics
def test_load_only(self):
core = _make_core(enable_throughput_scaling=False)
tick = core.initial_tick(start_s=0.0)
assert tick.at_s == 5.0
assert tick.need_worker_fpm
assert not tick.need_traffic_metrics
def test_throughput_only(self):
core = _make_core(enable_load_scaling=False)
tick = core.initial_tick(start_s=0.0)
# Load tick is still scheduled (feeds regression) at 5s < 60s
assert tick.at_s == 5.0
assert tick.need_worker_fpm
# ── Load benchmark bootstrapping ──────────────────────────────────────
class TestBenchmarkBootstrap:
def test_prefill_regression_bootstrapped(self):
core = _make_core(mode="prefill")
_train_prefill_regression(core)
assert core.prefill_regression.has_sufficient_data()
def test_decode_regression_bootstrapped(self):
core = _make_core(mode="decode")
_train_decode_regression(core)
assert core.decode_regression.has_sufficient_data()
# ── FPM observation via on_tick ───────────────────────────────────────
class TestFpmObservation:
def test_fpm_feeds_regression(self):
core = _make_core(mode="prefill")
assert core.prefill_regression.num_observations == 0
fpm = _make_fpm(sum_prefill_tokens=500, num_prefill_requests=1, wall_time=0.5)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(prefill={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
core.on_tick(_tick_for(tick), tick)
assert core.prefill_regression.num_observations == 1
def test_next_tick_scheduled_after_fpm(self):
core = _make_core(mode="prefill")
tick = TickInput(
now_s=10.0,
fpm_observations=FpmObservations(
prefill={
("w1", 0): _make_fpm(
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
)
}
),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.next_tick is not None
assert effects.next_tick.at_s == 15.0
assert effects.next_tick.need_worker_fpm
# ── Load-based scaling (prefill) ──────────────────────────────────────
class TestPrefillLoadScaling:
def test_scale_up_when_all_above_sla(self):
core = _make_core(mode="prefill", ttft=5.0)
_train_prefill_regression(core)
fpm = _make_fpm(
worker_id="w1",
queued_prefill_tokens=10000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(prefill={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is not None
assert effects.scale_to.num_prefill is not None
assert effects.scale_to.num_prefill > 1
def test_no_scaling_when_insufficient_data(self):
core = _make_core(mode="prefill")
fpm = _make_fpm(
queued_prefill_tokens=5000, sum_prefill_tokens=100, wall_time=0.1
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(prefill={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is None
def test_no_scaling_when_load_disabled(self):
core = _make_core(mode="prefill", enable_load_scaling=False)
_train_prefill_regression(core)
fpm = _make_fpm(
queued_prefill_tokens=10000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(prefill={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is None
# ── Load-based scaling (decode) ───────────────────────────────────────
class TestDecodeLoadScaling:
def test_scale_up_when_all_above_sla(self):
core = _make_core(mode="decode", itl=5.0)
_train_decode_regression(core)
fpm = _make_fpm(
worker_id="w1",
sum_decode_kv_tokens=30000,
queued_decode_kv_tokens=20000,
num_decode_requests=30,
wall_time=0.3,
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(decode={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_decode=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is not None
assert effects.scale_to.num_decode is not None
assert effects.scale_to.num_decode > 1
# ── Disagg load scaling ───────────────────────────────────────────────
class TestDisaggLoadScaling:
def test_disagg_scale_up(self):
core = _make_core(ttft=5.0, itl=5.0)
_train_prefill_regression(core)
_train_decode_regression(core)
p_fpm = _make_fpm(
worker_id="w1",
queued_prefill_tokens=10000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
)
d_fpm = _make_fpm(
worker_id="w1",
sum_decode_kv_tokens=5000,
queued_decode_kv_tokens=3000,
num_decode_requests=20,
wall_time=0.6,
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(
prefill={("w1", 0): p_fpm},
decode={("w1", 0): d_fpm},
),
worker_counts=WorkerCounts(ready_num_prefill=1, ready_num_decode=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is not None
# ── Throughput scaling ────────────────────────────────────────────────
class TestThroughputScaling:
def test_throughput_only_returns_decision(self):
core = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core)
# Warm predictor with traffic
core._observe_traffic(
TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150)
)
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is not None
assert effects.scale_to.num_prefill is not None
assert effects.scale_to.num_prefill >= 1
def test_throughput_sets_lower_bound_when_load_enabled(self):
core = _make_core(enable_load_scaling=True, enable_throughput_scaling=True)
_train_prefill_regression(core)
_train_decode_regression(core)
core._observe_traffic(
TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150)
)
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150),
worker_counts=WorkerCounts(ready_num_prefill=1, ready_num_decode=1),
)
effects = core.on_tick(_tick_for(tick), tick)
# When both modes enabled, throughput tick returns None (just sets lower bound)
assert effects.scale_to is None
assert core._throughput_lower_bound_p >= 1
assert core._throughput_lower_bound_d >= 1
def test_next_tick_scheduled_after_traffic(self):
core = _make_core(mode="prefill")
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(duration_s=60, num_req=0, isl=0, osl=0),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.next_tick is not None
assert effects.next_tick.need_traffic_metrics
assert effects.next_tick.at_s == 120.0
# ── FPM reconciliation ───────────────────────────────────────────────
class TestFpmReconciliation:
def test_mismatch_skips_scaling(self):
core = _make_core(mode="prefill", ttft=5.0)
_train_prefill_regression(core)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(
prefill={
("w1", 0): _make_fpm(
queued_prefill_tokens=10000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
),
("w2", 0): _make_fpm(
worker_id="w2",
queued_prefill_tokens=8000,
sum_prefill_tokens=500,
num_prefill_requests=1,
wall_time=0.5,
),
}
),
worker_counts=WorkerCounts(ready_num_prefill=3),
)
effects = core.on_tick(_tick_for(tick), tick)
# FPM reports 2 workers but ready count is 3 -> skip scaling
assert effects.scale_to is None
# ── Agg planner core ──────────────────────────────────────────────────
class TestAggPlannerStateMachine:
def _train_agg(self, core: PlannerStateMachine) -> None:
fpms = [
_make_fpm(
sum_prefill_tokens=p,
num_prefill_requests=1,
sum_decode_kv_tokens=d,
num_decode_requests=10,
wall_time=0.001 * p + 0.0001 * d + 0.001,
)
for p, d in [
(100, 1000),
(200, 2000),
(300, 3000),
(400, 4000),
(500, 5000),
]
]
core.load_benchmark_fpms(agg_fpms=fpms)
def test_initial_tick(self):
core = _make_agg_core()
tick = core.initial_tick(start_s=0.0)
assert tick.at_s == 5.0
assert tick.need_worker_fpm
def test_fpm_feeds_regression(self):
core = _make_agg_core()
assert core.regression.num_observations == 0
fpm = _make_fpm(
sum_prefill_tokens=200,
num_prefill_requests=1,
sum_decode_kv_tokens=2000,
num_decode_requests=10,
wall_time=0.3,
)
tick = TickInput(
now_s=5.0,
fpm_observations=FpmObservations(decode={("w1", 0): fpm}),
worker_counts=WorkerCounts(ready_num_decode=1),
)
core.on_tick(_tick_for(tick), tick)
assert core.regression.num_observations == 1
def test_throughput_only_returns_decision(self):
core = _make_agg_core(enable_load_scaling=False, enable_throughput_scaling=True)
self._train_agg(core)
core._observe_traffic(
TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150)
)
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(duration_s=60, num_req=100, isl=1000, osl=150),
worker_counts=WorkerCounts(ready_num_decode=1),
)
effects = core.on_tick(_tick_for(tick), tick)
assert effects.scale_to is not None
assert effects.scale_to.num_decode is not None
assert effects.scale_to.num_decode >= 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment