"vscode:/vscode.git/clone" did not exist on "122777c8aaef748d6756095fc8ea35e31cc1094b"
Unverified Commit f7e0b3fd authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor(planner): extract discrete-event state machine with explicit inputs (#8046)

parent 39a6a240
......@@ -21,10 +21,12 @@ from typing import Union
from pydantic import BaseModel
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.agg import AggPlanner
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.disagg import DisaggPlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.adapters import (
AggPlanner,
DecodePlanner,
DisaggPlanner,
PrefillPlanner,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
EngineCapabilities,
FpmObservations,
PlannerEffects,
ScalingDecision,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
# Names exported by a wildcard import of this module: the state machine
# class plus the names imported above from ``dynamo.planner.core.types``.
__all__ = [
"EngineCapabilities",
"FpmObservations",
"PlannerEffects",
"PlannerStateMachine",
"ScalingDecision",
"ScheduledTick",
"TickInput",
"TrafficObservation",
"WorkerCapabilities",
"WorkerCounts",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Native planner adapter subclasses (one per mode).
Each subclass sets ``require_prefill`` / ``require_decode`` and overrides
``_bootstrap_regression()`` and ``_apply_effects()``. Everything else
(connector, Prometheus, FPM subscribers, tick loop) is in ``NativePlannerBase``.
"""
import logging
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.core.base import NativePlannerBase
from dynamo.planner.core.types import PlannerEffects
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
logger = logging.getLogger(__name__)
class PrefillPlanner(NativePlannerBase):
    """Native planner adapter for prefill-only deployments."""

    require_prefill = True
    require_decode = False

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment prefill benchmark FPMs into the state machine.

        Any failure is re-raised when throughput scaling is enabled;
        otherwise it is logged as a warning and ignored.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.prefill_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.PREFILL,
            )
            self.state_machine.load_benchmark_fpms(prefill_fpms=benchmark_fpms)
        except Exception as exc:
            if not self.config.enable_throughput_scaling:
                logger.warning(f"No pre-deployment data for prefill: {exc}")
                return
            raise

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Publish and apply the prefill replica count carried by ``effects``.

        Returns without doing anything when ``effects`` holds no scaling
        decision, or when that decision leaves the prefill count unset.
        """
        decision = effects.scale_to
        replicas = decision.num_prefill if decision is not None else None
        if replicas is None:
            return

        # A port of 0 means the Prometheus gauge is not updated.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_p.set(replicas)

        prefill_target = TargetReplica(
            sub_component_type=SubComponentType.PREFILL,
            component_name=self.prefill_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([prefill_target])
class DecodePlanner(NativePlannerBase):
    """Native planner adapter for decode-only deployments."""

    require_prefill = False
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment decode benchmark FPMs into the state machine.

        Any failure is re-raised when throughput scaling is enabled;
        otherwise it is logged as a warning and ignored.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.decode_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.DECODE,
            )
            self.state_machine.load_benchmark_fpms(decode_fpms=benchmark_fpms)
        except Exception as exc:
            must_fail = self.config.enable_throughput_scaling
            if must_fail:
                raise
            logger.warning(f"No pre-deployment data for decode: {exc}")

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Publish and apply the decode replica count carried by ``effects``.

        Returns without doing anything when ``effects`` holds no scaling
        decision, or when that decision leaves the decode count unset.
        """
        decision = effects.scale_to
        if decision is None:
            return
        replicas = decision.num_decode
        if replicas is None:
            return

        # A port of 0 means the Prometheus gauge is not updated.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_d.set(replicas)

        decode_target = TargetReplica(
            sub_component_type=SubComponentType.DECODE,
            component_name=self.decode_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([decode_target])
class AggPlanner(NativePlannerBase):
    """Native planner adapter for aggregated deployments.

    A single engine type handles both prefill and decode; it is addressed
    through the decode worker info and the decode sub-component type.
    """

    require_prefill = False
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment benchmark FPMs as aggregated FPMs.

        Any failure is re-raised when throughput scaling is enabled;
        otherwise it is logged as a warning and ignored.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.decode_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.DECODE,
            )
            self.state_machine.load_benchmark_fpms(agg_fpms=benchmark_fpms)
        except Exception as exc:
            if not self.config.enable_throughput_scaling:
                logger.warning(f"No pre-deployment data for agg: {exc}")
                return
            raise

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Publish and apply the decode replica count carried by ``effects``.

        Returns without doing anything when ``effects`` holds no scaling
        decision, or when that decision leaves the decode count unset.
        """
        decision = effects.scale_to
        replicas = None if decision is None else decision.num_decode
        if replicas is None:
            return

        # A port of 0 means the Prometheus gauge is not updated.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_d.set(replicas)

        engine_target = TargetReplica(
            sub_component_type=SubComponentType.DECODE,
            component_name=self.decode_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([engine_target])
class DisaggPlanner(NativePlannerBase):
    """Disaggregated mode (separate prefill and decode engines)."""

    require_prefill = True
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment benchmark FPMs for prefill, then for decode.

        Each component is fetched independently. A failure for a component
        is re-raised when throughput scaling is enabled; otherwise it is
        logged as a warning and the loop moves on to the next component.
        """
        for component, kwarg in [
            (SubComponentType.PREFILL, "prefill_fpms"),
            (SubComponentType.DECODE, "decode_fpms"),
        ]:
            worker_info = (
                self.prefill_worker_info
                if component == SubComponentType.PREFILL
                else self.decode_worker_info
            )
            try:
                fpms = await fetch_pre_deployment_metrics(
                    runtime=self.runtime,
                    namespace=self.namespace,
                    worker_info=worker_info,
                    profile_results_dir=self.config.profile_results_dir,
                    component_type=component,
                )
                self.state_machine.load_benchmark_fpms(**{kwarg: fpms})
            except Exception as e:
                if self.config.enable_throughput_scaling:
                    raise
                logger.warning(f"No pre-deployment data for {component.value}: {e}")

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Publish and apply the replica counts carried by ``effects``.

        Only the counts that are set on the scaling decision are published
        and turned into targets. When there is no decision, or the decision
        sets neither count, nothing is applied.
        """
        decision = effects.scale_to
        if decision is None:
            return

        # A port of 0 means the Prometheus gauges are not updated.
        publish_metrics = self.prometheus_port != 0
        targets = []

        if decision.num_prefill is not None:
            if publish_metrics:
                self.prometheus_metrics.predicted_num_p.set(decision.num_prefill)
            targets.append(
                TargetReplica(
                    sub_component_type=SubComponentType.PREFILL,
                    component_name=self.prefill_worker_info.k8s_name,
                    desired_replicas=decision.num_prefill,
                )
            )

        if decision.num_decode is not None:
            if publish_metrics:
                self.prometheus_metrics.predicted_num_d.set(decision.num_decode)
            targets.append(
                TargetReplica(
                    sub_component_type=SubComponentType.DECODE,
                    component_name=self.decode_worker_info.k8s_name,
                    desired_replicas=decision.num_decode,
                )
            )

        # Match the single-component adapters, which return early when their
        # count is unset: never issue a scaling call with no targets.
        if not targets:
            return
        await self._apply_scaling_targets(targets)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import math
import time
from typing import TYPE_CHECKING, Optional
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.base import BasePlanner
from dynamo.planner.core.budget import (
_apply_component_gpu_budget,
_initialize_gpu_counts,
)
from dynamo.planner.core.perf_model import AggRegressionModel
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class AggPlanner:
    """FPM-driven planner for aggregated deployments (one engine type).

    Aggregated engines serve prefill and decode together (chunked prefill),
    so a single ``AggRegressionModel`` is used for both the TTFT and the ITL
    estimates. Load-only, throughput-only, or combined scaling is supported.

    Load-based scaling:
        * A next-TTFT and a next-ITL estimate are computed per engine from
          its forward-pass metrics.
        * Each set of estimates is turned into a desired replica count by
          the wrapped planner; a scale-up from either wins, and a
          scale-down requires both to agree.

    Throughput-based scaling:
        * Predicted demand (requests per second) is divided by the best
          per-engine request rate reported by the regression model.
    """

    def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
        """Create shared state, the wrapped ``BasePlanner`` and the regression.

        Raises:
            ValueError: if neither throughput nor load scaling is enabled.
        """
        self.config = config
        self.runtime = runtime
        self.shared_state = PlannerSharedState()
        self.enable_throughput = config.enable_throughput_scaling
        self.enable_load = config.enable_load_scaling
        if not (self.enable_throughput or self.enable_load):
            raise ValueError(
                "Aggregated planner requires at least one scaling mode enabled."
            )

        metrics = PlannerPrometheusMetrics()
        self.planner = BasePlanner(
            runtime,
            config,
            shared_state=self.shared_state,
            prometheus_metrics=metrics,
            start_prometheus_server=True,
            component_type=SubComponentType.DECODE,
        )
        self.regression = AggRegressionModel(
            max_num_fpm_samples=config.max_num_fpm_samples,
            min_observations=config.load_min_observations,
            bucket_count=config.fpm_sample_bucket_size,
        )

    async def _async_init(self):
        """Validate the deployment and prime worker info, FPM and regression.

        Connector initialisation, deployment validation, GPU-count
        initialisation and the readiness wait are skipped in no-operation mode.
        """
        requirements = {"require_prefill": False, "require_decode": True}
        defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)

        if not self.config.no_operation:
            maybe_connector = getattr(self.planner, "connector", None)
            if maybe_connector and hasattr(maybe_connector, "_async_init"):
                await maybe_connector._async_init()

            logger.info("Validating deployment...")
            await self.planner.connector.validate_deployment(
                prefill_component_name=None,
                decode_component_name=(
                    defaults.decode_worker_k8s_name if defaults else None
                ),
                **requirements,
            )
            logger.info("Successfully validated the deployment")

            _initialize_gpu_counts(
                self.config, self.planner.connector, **requirements
            )
            await self.planner.connector.wait_for_deployment_ready(
                include_planner=False
            )

        await self.planner._init_worker_info(**requirements)
        if self.runtime is not None:
            await self.planner._init_fpm_subscriber()
        await self._bootstrap_regression()

    async def _bootstrap_regression(self) -> None:
        """Seed the regression model from pre-deployment benchmark data.

        A failure is re-raised when throughput scaling is enabled; otherwise
        it is logged as a warning and ignored.
        """
        worker_info = self.planner.decode_worker_info
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.config.namespace,
                worker_info=worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.DECODE,
            )
            self.regression.load_benchmark_fpms(benchmark_fpms)
            logger.info(
                f"Bootstrapped agg regression with {len(benchmark_fpms)} pre-deployment FPMs"
            )
        except Exception as exc:
            if not self.enable_throughput:
                logger.warning(
                    f"No pre-deployment data for agg regression: {exc}. "
                    "Load-based scaling will learn from live FPM only."
                )
                return
            raise

    async def run(self):
        """Run the scaling loops concurrently. Call ``_async_init()`` first."""
        self.shared_state.last_adjustment_time = time.time()
        loops = [self._load_and_fpm_update_loop()]
        if self.enable_throughput:
            # Keep the throughput loop first, as in the gather() argument order.
            loops.insert(0, self._throughput_loop())
        await asyncio.gather(*loops)

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    def _usable_max_num_batched_tokens(self):
        """Return the decode worker's max_num_batched_tokens, or None if unusable."""
        value = getattr(
            self.planner.decode_worker_info, "max_num_batched_tokens", None
        )
        if not value or value <= 0:
            return None
        return value

    def _publish_predicted_decode(self, replicas: int) -> None:
        """Update the predicted-decode gauge when Prometheus export is active."""
        wrapped = self.planner
        if wrapped.prometheus_port != 0 and wrapped.prometheus_metrics is not None:
            wrapped.prometheus_metrics.predicted_num_d.set(replicas)

    async def _request_decode_replicas(self, replicas: int) -> None:
        """Ask the connector (non-blocking) to scale the decode component."""
        target_replicas = [
            TargetReplica(
                sub_component_type=SubComponentType.DECODE,
                component_name=self.planner.decode_worker_info.k8s_name,
                desired_replicas=replicas,
            )
        ]
        await self.planner.connector.set_component_replicas(
            target_replicas, blocking=False
        )

    @staticmethod
    def _merge_load_decisions(p_desired, d_desired, num_workers) -> Optional[int]:
        """Combine the TTFT and ITL decisions into one target, or None.

        A scale-up from the TTFT side wins first, then one from the ITL side.
        A scale-down needs both sides to ask for fewer workers, and then the
        larger of the two counts is used.
        """
        if p_desired is not None and p_desired > num_workers:
            return p_desired
        if d_desired is not None and d_desired > num_workers:
            return d_desired
        both_known = p_desired is not None and d_desired is not None
        if both_known and p_desired < num_workers and d_desired < num_workers:
            return max(p_desired, d_desired)
        return None

    # ------------------------------------------------------------------ #
    # Throughput-based scaling
    # ------------------------------------------------------------------ #

    async def _throughput_loop(self) -> None:
        """Throughput-based scaling loop for agg mode.

        Polls every tenth of the throughput interval and runs one adjustment
        once a full interval has elapsed since the last one.
        """
        while True:
            elapsed = time.time() - self.shared_state.last_adjustment_time
            if elapsed >= self.config.throughput_adjustment_interval:
                await self._run_throughput_adjustment()
            await asyncio.sleep(self.config.throughput_adjustment_interval / 10)

    async def _run_throughput_adjustment(self) -> None:
        """Perform one throughput adjustment; returns early when it must skip."""
        self.shared_state.last_adjustment_time = time.time()
        logger.info("New agg throughput adjustment interval started!")

        await self.planner.observe_traffic_stats(
            require_prefill=False, require_decode=True
        )
        if not self.shared_state.last_metrics.is_valid():
            logger.info("Metrics invalid, skipping agg throughput adjustment")
            return

        next_num_req = self.planner.num_req_predictor.predict_next()
        next_isl = self.planner.isl_predictor.predict_next()
        next_osl = self.planner.osl_predictor.predict_next()

        max_num_batched_tokens = self._usable_max_num_batched_tokens()
        if max_num_batched_tokens is None:
            logger.warning(
                "max_num_batched_tokens not available, skipping agg throughput"
            )
            return

        engine_rps, actual_ttft, actual_itl = self.regression.find_best_engine_agg_rps(
            isl=next_isl,
            osl=next_osl,
            max_num_batched_tokens=max_num_batched_tokens,
            ttft_sla=self.config.ttft,
            itl_sla=self.config.itl,
        )
        if engine_rps <= 0:
            logger.warning("Agg perf model not ready, skipping throughput scaling")
            return

        if actual_ttft > self.config.ttft or actual_itl > self.config.itl:
            logger.warning(
                f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms "
                f"(target {self.config.ttft:.1f}ms), "
                f"ITL={actual_itl:.1f}ms (target {self.config.itl:.1f}ms), "
                "scaling with best achievable rate"
            )

        demand_rps = next_num_req / self.config.throughput_adjustment_interval
        desired = max(math.ceil(demand_rps / engine_rps), self.config.min_endpoint)
        logger.info(
            f"Agg: {demand_rps:.2f}(demand rps) / "
            f"{engine_rps:.2f}(engine rps) = {desired}(replicas), "
            f"est_ttft={actual_ttft:.1f}ms, est_itl={actual_itl:.1f}ms"
        )

        if self.enable_load:
            # With load scaling on, throughput only provides a lower bound.
            self.shared_state.throughput_lower_bound_d = desired
            logger.info(f"Agg throughput lower bound set to {desired}")
            return

        # Throughput-only: apply the result directly.
        assert self.config.decode_engine_num_gpu is not None
        desired = _apply_component_gpu_budget(
            desired, self.config.decode_engine_num_gpu, self.config
        )
        self._publish_predicted_decode(desired)
        if not self.config.no_operation:
            await self._request_decode_replicas(desired)

    # ------------------------------------------------------------------ #
    # Load-based scaling
    # ------------------------------------------------------------------ #

    async def _load_and_fpm_update_loop(self) -> None:
        """FPM observation and (optionally) load-based scaling for agg mode.

        Every iteration feeds live FPM into the regression. When load-based
        scaling is enabled, a scaling decision follows immediately, unless a
        previously requested scale has not yet been reached.
        """
        pending_desired: Optional[int] = None
        while True:
            await asyncio.sleep(self.config.load_adjustment_interval)
            logger.info("New agg load/FPM update interval started!")

            _, num_workers, _ = await self.planner.get_workers_info(
                require_prefill=False, require_decode=True
            )
            self.shared_state.num_d_workers = num_workers

            fpm_stats = self.planner._get_fpm_stats()
            if not fpm_stats:
                continue

            for (wid, dp), fpm in fpm_stats.items():
                BasePlanner._log_fpm(wid, dp, fpm, "agg")
                self.regression.add_observation(fpm)

            if not self.enable_load:
                continue

            if pending_desired is not None:
                if num_workers != pending_desired:
                    logger.info(
                        f"Scaling in progress ({num_workers} -> {pending_desired}), "
                        "observing only"
                    )
                    continue
                logger.info(
                    f"Scaling to {pending_desired} complete, resuming decisions"
                )
                pending_desired = None

            if not BasePlanner._reconcile_fpm_worker_count(
                fpm_stats, num_workers, "agg"
            ):
                continue

            if not self.regression.has_sufficient_data():
                logger.info(
                    f"Agg regression: insufficient data "
                    f"({self.regression.num_observations}/{self.regression.min_observations})"
                )
                continue

            max_num_batched_tokens = self._usable_max_num_batched_tokens()
            if max_num_batched_tokens is None:
                logger.warning(
                    "max_num_batched_tokens not available from WorkerInfo, "
                    "skipping agg scaling"
                )
                continue

            p_desired = self._prefill_scaling_decision(
                fpm_stats, num_workers, max_num_batched_tokens
            )
            d_desired = self._decode_scaling_decision(fpm_stats, num_workers)
            logger.info(
                f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} "
                f"(current={num_workers})"
            )

            desired = self._merge_load_decisions(p_desired, d_desired, num_workers)
            if desired is None:
                logger.info("Agg scaling: no scaling needed")
                continue

            desired = max(desired, self.config.min_endpoint)
            if self.enable_throughput:
                desired = max(desired, self.shared_state.throughput_lower_bound_d)

            assert self.config.decode_engine_num_gpu is not None
            desired = _apply_component_gpu_budget(
                desired, self.config.decode_engine_num_gpu, self.config
            )
            logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")

            self._publish_predicted_decode(desired)
            if not self.config.no_operation:
                # Remember the target so later iterations only observe until
                # the worker count matches it.
                pending_desired = desired
                await self._request_decode_replicas(desired)

    def _prefill_scaling_decision(
        self,
        fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
        num_workers: int,
        max_num_batched_tokens: int,
    ) -> Optional[int]:
        """Turn per-engine TTFT estimates (in ms) into a desired replica count.

        Engines for which the regression returns no estimate are left out.
        """
        raw_estimates = (
            self.regression.estimate_next_ttft(
                queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
                max_num_batched_tokens=max_num_batched_tokens,
                current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        ttfts_ms = [value * 1000 for value in raw_estimates if value is not None]
        return self.planner._load_based_scaling_decision_from_estimates(
            ttfts_ms, self.config.ttft, num_workers, "agg TTFT"
        )

    def _decode_scaling_decision(
        self,
        fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
        num_workers: int,
    ) -> Optional[int]:
        """Turn per-engine ITL estimates (in ms) into a desired replica count.

        Engines for which the regression returns no estimate are left out.
        """
        raw_estimates = (
            self.regression.estimate_next_itl(
                scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
                queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        itls_ms = [value * 1000 for value in raw_estimates if value is not None]
        return self.planner._load_based_scaling_decision_from_estimates(
            itls_ms, self.config.itl, num_workers, "agg ITL"
        )
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Optional
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.core.base import BasePlanner
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class DecodePlanner(BasePlanner):
    """Decode-side planner built on ``BasePlanner`` and its ITL regression."""

    component_type = SubComponentType.DECODE

    def load_plan_adjustment(self) -> Optional[int]:
        """Load-based scaling decision for decode using FPM data.

        For every engine with FPM data, the ITL regression estimates the next
        decode ITL from its scheduled and queued decode KV tokens. The
        estimates (converted to milliseconds) are handed to
        ``_load_based_scaling_decision_from_estimates`` together with the ITL
        SLA and the current decode worker count.

        Returns:
            The value returned by that helper, or ``None`` when the
            regression lacks data, there are no FPM stats, or there are no
            decode workers.
        """
        regression = self.itl_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
                f"/{self.itl_regression.min_observations}), skipping load-based scaling"
            )
            return None

        fpm_stats = self._get_fpm_stats()
        if not fpm_stats:
            return None

        num_workers = self.shared_state.num_d_workers
        if num_workers == 0:
            return None

        itls_ms: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            scheduled_kv = fpm.scheduled_requests.sum_decode_kv_tokens
            queued_kv = fpm.queued_requests.sum_decode_kv_tokens
            estimate = self.itl_regression.estimate_next_itl(
                scheduled_decode_kv=scheduled_kv,
                queued_decode_kv=queued_kv,
            )
            if estimate is not None:
                est_ms = estimate * 1000
                itls_ms.append(est_ms)
                logger.info(
                    f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
                    f"(sched_kv={scheduled_kv}, queued_kv={queued_kv}, "
                    f"avg_decode_len={self.itl_regression.avg_decode_length:.1f})"
                )

        return self._load_based_scaling_decision_from_estimates(
            estimates=itls_ms,
            sla=self.config.itl,
            num_workers=num_workers,
            label="decode ITL",
        )

    def _compute_replica_requirements(
        self, next_num_req: float, next_isl: float, next_osl: float
    ) -> Optional[int]:
        """Return the decode replica count needed for the predicted traffic.

        Demand in requests per second is divided by the best per-engine rate
        from the ITL regression, rounded up and floored at
        ``config.min_endpoint``. Returns ``None`` when the regression reports
        a non-positive engine rate.
        """
        demand_rps = next_num_req / self.config.throughput_adjustment_interval
        # Context length passed to the model: input length plus half the output.
        best = self.itl_regression.find_best_engine_decode_rps(
            itl=self.config.itl,
            context_length=next_isl + next_osl / 2,
            osl=next_osl,
        )
        engine_rps, actual_itl_ms = best
        if engine_rps <= 0:
            logger.warning("Decode perf model not ready, skipping throughput scaling")
            return None

        if actual_itl_ms > self.config.itl:
            logger.warning(
                f"Decode ITL SLA not met: {actual_itl_ms:.1f}ms > "
                f"{self.config.itl:.1f}ms, scaling with best achievable rate"
            )

        next_num_d = max(math.ceil(demand_rps / engine_rps), self.config.min_endpoint)
        logger.info(
            f"Decode: {demand_rps:.2f}(demand rps) / "
            f"{engine_rps:.2f}(engine rps) = {next_num_d}(num_d), "
            f"est_itl={actual_itl_ms:.1f}ms"
        )
        return next_num_d

    def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
        """Set the predicted-decode gauge when Prometheus export is active."""
        if self.prometheus_port == 0 or self.prometheus_metrics is None:
            return
        self.prometheus_metrics.predicted_num_d.set(desired_replicas)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import time
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.base import BasePlanner
from dynamo.planner.core.budget import _apply_global_gpu_budget, _initialize_gpu_counts
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class DisaggPlanner:
    """Drives a ``PrefillPlanner`` and a ``DecodePlanner`` over shared state.

    Both sub-planners share one ``PlannerSharedState`` and one Prometheus
    metrics object; the decode planner reuses the prefill planner's
    connector and traffic client.
    """

    def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
        """Build the shared state and the two sub-planners."""
        self.config = config
        self.shared_state = PlannerSharedState()
        prometheus_metrics = PlannerPrometheusMetrics()
        self.enable_throughput = config.enable_throughput_scaling
        self.enable_load = config.enable_load_scaling

        shared_kwargs = {
            "shared_state": self.shared_state,
            "prometheus_metrics": prometheus_metrics,
        }
        # Only the prefill planner starts the Prometheus server.
        self.prefill_planner = PrefillPlanner(
            runtime,
            config,
            start_prometheus_server=True,
            **shared_kwargs,
        )
        self.decode_planner = DecodePlanner(
            runtime,
            config,
            prometheus_traffic_client=getattr(
                self.prefill_planner, "prometheus_traffic_client", None
            ),
            connector=getattr(self.prefill_planner, "connector", None),
            start_prometheus_server=False,
            **shared_kwargs,
        )

    async def _async_init(self):
        """Initialise both sub-planners, sharing worker info between them.

        Connector initialisation, deployment validation, GPU-count
        initialisation and the readiness wait are skipped in no-operation mode.
        """
        requirements = {"require_prefill": True, "require_decode": True}
        prefill, decode = self.prefill_planner, self.decode_planner
        defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)

        if not self.config.no_operation:
            # Prefill and decode share the same connector, so initialise once.
            maybe_connector = getattr(prefill, "connector", None)
            if maybe_connector and hasattr(maybe_connector, "_async_init"):
                await maybe_connector._async_init()

            logger.info("Validating deployment...")
            await prefill.connector.validate_deployment(
                prefill_component_name=(
                    defaults.prefill_worker_k8s_name if defaults else None
                ),
                decode_component_name=(
                    defaults.decode_worker_k8s_name if defaults else None
                ),
                **requirements,
            )
            logger.info("Successfully validated the deployment")

            _initialize_gpu_counts(self.config, prefill.connector, **requirements)
            await prefill.connector.wait_for_deployment_ready(include_planner=False)

        await prefill._init_worker_info(**requirements)

        # Hand the discovered worker info and model name to the decode planner.
        decode.prefill_worker_info = prefill.prefill_worker_info
        decode.decode_worker_info = prefill.decode_worker_info
        decode.model_name = prefill.model_name

        if prefill.runtime is not None:
            await prefill._init_fpm_subscriber()
        if decode.runtime is not None:
            await decode._init_fpm_subscriber()

        await prefill._bootstrap_regression()
        await decode._bootstrap_regression()

    async def run(self):
        """Run the scaling loops concurrently. Call ``_async_init()`` first."""
        now_state = self.shared_state
        now_state.last_adjustment_time = time.time()
        now_state.last_load_adjustment_time = time.time()

        loops = [self._load_and_fpm_update_loop()]
        if self.enable_throughput:
            # Keep the throughput loop first, as in the gather() argument order.
            loops.insert(0, self._throughput_loop())
        await asyncio.gather(*loops)

    async def _set_replicas(self, num_prefill: int, num_decode: int, blocking: bool) -> None:
        """Send prefill and decode replica targets through the shared connector."""
        target_replicas = [
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL,
                component_name=self.prefill_planner.prefill_worker_info.k8s_name,
                desired_replicas=num_prefill,
            ),
            TargetReplica(
                sub_component_type=SubComponentType.DECODE,
                component_name=self.prefill_planner.decode_worker_info.k8s_name,
                desired_replicas=num_decode,
            ),
        ]
        await self.prefill_planner.connector.set_component_replicas(
            target_replicas, blocking=blocking
        )

    async def _throughput_loop(self) -> None:
        """Throughput-based scaling loop for disagg mode.

        Polls every tenth of the throughput interval and runs one adjustment
        once a full interval has elapsed since the last one.
        """
        while True:
            elapsed = time.time() - self.shared_state.last_adjustment_time
            if elapsed >= self.config.throughput_adjustment_interval:
                await self._run_throughput_adjustment()
            await asyncio.sleep(self.config.throughput_adjustment_interval / 10)

    async def _run_throughput_adjustment(self) -> None:
        """Perform one throughput adjustment; returns early when it must skip."""
        self.shared_state.last_adjustment_time = time.time()
        logger.info("New throughput adjustment interval started!")

        await self.prefill_planner.observe_traffic_stats(
            require_prefill=True, require_decode=True
        )
        self.decode_planner.update_predictors_from_metrics(
            self.shared_state.last_metrics
        )

        next_num_p = self.prefill_planner.plan_adjustment()
        next_num_d = self.decode_planner.plan_adjustment()
        if next_num_p is None or next_num_d is None:
            return

        if self.enable_load:
            # When load-based scaling is also enabled, only set lower bounds.
            self.shared_state.throughput_lower_bound_p = next_num_p
            self.shared_state.throughput_lower_bound_d = next_num_d
            logger.info(
                f"Throughput lower bounds set: prefill={next_num_p}, decode={next_num_d}"
            )
            return

        # Throughput-only: apply scaling directly.
        next_num_p, next_num_d = _apply_global_gpu_budget(
            next_num_p, next_num_d, self.config
        )
        self.prefill_planner.update_predicted_replicas_metric(next_num_p)
        self.decode_planner.update_predicted_replicas_metric(next_num_d)
        if not self.config.no_operation:
            await self._set_replicas(next_num_p, next_num_d, blocking=False)

    async def _load_and_fpm_update_loop(self) -> None:
        """FPM observation and (optionally) load-based scaling for disagg mode.

        Every iteration lets both sub-planners observe live FPM. When
        load-based scaling is enabled, a scaling decision follows immediately.
        """
        state = self.shared_state
        while True:
            await asyncio.sleep(self.config.load_adjustment_interval)
            logger.info("New load/FPM update interval started!")

            num_p, num_d, _ = await self.prefill_planner.get_workers_info(
                require_prefill=True, require_decode=True
            )
            state.num_p_workers = num_p
            state.num_d_workers = num_d

            p_stats = self.prefill_planner.observe_fpm_load_stats()
            d_stats = self.decode_planner.observe_fpm_load_stats()

            if not self.enable_load:
                continue

            if not (p_stats or d_stats):
                logger.warning("No FPM data for either prefill or decode, skipping")
                continue

            if p_stats and not BasePlanner._reconcile_fpm_worker_count(
                p_stats, num_p, "prefill"
            ):
                continue
            if d_stats and not BasePlanner._reconcile_fpm_worker_count(
                d_stats, num_d, "decode"
            ):
                continue

            p_desired = self.prefill_planner.load_plan_adjustment()
            d_desired = self.decode_planner.load_plan_adjustment()

            # A missing decision means "keep the current worker count".
            final_p = state.num_p_workers if p_desired is None else p_desired
            final_d = state.num_d_workers if d_desired is None else d_desired

            unchanged = (
                final_p == state.num_p_workers and final_d == state.num_d_workers
            )
            if unchanged:
                logger.info("Load-based scaling: no scaling needed")
                continue

            if self.enable_throughput:
                final_p = max(final_p, state.throughput_lower_bound_p)
                final_d = max(final_d, state.throughput_lower_bound_d)

            final_p = max(final_p, self.config.min_endpoint)
            final_d = max(final_d, self.config.min_endpoint)
            final_p, final_d = _apply_global_gpu_budget(final_p, final_d, self.config)

            logger.info(
                f"Load-based disagg scaling: prefill {self.shared_state.num_p_workers}->{final_p}, "
                f"decode {self.shared_state.num_d_workers}->{final_d}"
            )

            self.prefill_planner.update_predicted_replicas_metric(final_p)
            self.decode_planner.update_predicted_replicas_metric(final_d)

            if not self.config.no_operation:
                await self._set_replicas(final_p, final_d, blocking=True)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Load-based scaling logic (FPM-driven, reactive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from dynamo.planner.core.types import FpmObservations, ScalingDecision
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class LoadScalingMixin:
    """Reactive replica-count decisions derived from per-engine FPM data.

    The host class is expected to supply the state used here
    (``_config``, ``_capabilities``, worker counts, regression models,
    throughput lower bounds) together with ``_scaling_in_progress``,
    ``_reconcile_fpm_worker_count`` and the budget helpers.
    """

    def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Run one load-scaling step using the handler for the configured mode.

        Returns ``None`` when load scaling is disabled or the handler
        decides nothing should change.
        """
        cfg = self._config
        if not cfg.enable_load_scaling:
            return None
        mode = cfg.mode
        if mode == "agg":
            return self._advance_load_agg(obs)
        if mode == "disagg":
            return self._advance_load_disagg(obs)
        # Any other mode names the single component being planned.
        return self._advance_load_single(obs, mode)

    def _advance_load_single(
        self, obs: FpmObservations, component: str
    ) -> Optional[ScalingDecision]:
        """Load-scaling step for a planner that manages one component.

        ``component`` is ``"prefill"`` or anything else (treated as decode).
        """
        is_prefill = component == "prefill"
        if self._scaling_in_progress(component):
            logger.info(f"Scaling in progress for {component}, observing only")
            return None

        if is_prefill:
            stats, worker_count = obs.prefill, self._num_p_workers
        else:
            stats, worker_count = obs.decode, self._num_d_workers
        if not stats:
            return None
        if not self._reconcile_fpm_worker_count(stats, worker_count, component):
            return None

        if is_prefill:
            target = self._prefill_load_decision(stats, worker_count)
        else:
            target = self._decode_load_decision(stats, worker_count)
        if target is None:
            return None

        if self._config.enable_throughput_scaling:
            # Never go below what the throughput planner last asked for.
            if is_prefill:
                lower_bound = self._throughput_lower_bound_p
            else:
                lower_bound = self._throughput_lower_bound_d
            target = max(target, lower_bound)

        target = self._apply_single_budget(target, component)
        if is_prefill:
            return ScalingDecision(num_prefill=target)
        return ScalingDecision(num_decode=target)

    def _advance_load_disagg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Load-scaling step for disaggregated mode (prefill and decode)."""
        p_stats, d_stats = obs.prefill, obs.decode
        if not (p_stats or d_stats):
            logger.warning("No FPM data for either prefill or decode, skipping")
            return None

        cur_p, cur_d = self._num_p_workers, self._num_d_workers
        # A side with data must agree with the known inventory; a side
        # without data is simply left at its current size.
        if p_stats and not self._reconcile_fpm_worker_count(p_stats, cur_p, "prefill"):
            return None
        if d_stats and not self._reconcile_fpm_worker_count(d_stats, cur_d, "decode"):
            return None

        want_p = self._prefill_load_decision(p_stats, cur_p) if p_stats else None
        want_d = self._decode_load_decision(d_stats, cur_d) if d_stats else None
        final_p = cur_p if want_p is None else want_p
        final_d = cur_d if want_d is None else want_d

        if (final_p, final_d) == (cur_p, cur_d):
            logger.info("Load-based scaling: no scaling needed")
            return None

        cfg = self._config
        if cfg.enable_throughput_scaling:
            final_p = max(final_p, self._throughput_lower_bound_p)
            final_d = max(final_d, self._throughput_lower_bound_d)
        final_p = max(final_p, cfg.min_endpoint)
        final_d = max(final_d, cfg.min_endpoint)
        final_p, final_d = self._apply_global_budget(final_p, final_d)

        logger.info(
            f"Load-based disagg scaling: prefill {cur_p}->{final_p}, "
            f"decode {cur_d}->{final_d}"
        )
        return ScalingDecision(num_prefill=final_p, num_decode=final_d)

    def _advance_load_agg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Load-scaling step for aggregated mode.

        Aggregated engines report through ``obs.decode``.  Scale up when
        either the TTFT or the ITL estimate asks for it; scale down only
        when both agree.
        """
        fpm_stats = obs.decode
        if not fpm_stats:
            return None

        num_workers = self._num_d_workers
        if self._scaling_in_progress("decode"):
            logger.info(
                f"Scaling in progress ({num_workers} -> {self._expected_num_d}), observing only"
            )
            return None
        if not self._reconcile_fpm_worker_count(fpm_stats, num_workers, "agg"):
            return None

        model = self._agg_regression
        if not model.has_sufficient_data():
            logger.info(
                f"Agg regression: insufficient data "
                f"({model.num_observations}/{model.min_observations})"
            )
            return None

        decode_caps = self._capabilities.decode
        max_tokens = decode_caps.max_num_batched_tokens if decode_caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning("max_num_batched_tokens not available, skipping agg scaling")
            return None

        p_desired = self._agg_prefill_scaling(fpm_stats, num_workers, max_tokens)
        d_desired = self._agg_decode_scaling(fpm_stats, num_workers)
        logger.info(
            f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} (current={num_workers})"
        )

        p_known = p_desired is not None
        d_known = d_desired is not None
        if p_known and p_desired > num_workers:
            desired = p_desired
        elif d_known and d_desired > num_workers:
            desired = d_desired
        elif p_known and p_desired < num_workers and d_known and d_desired < num_workers:
            # Both signals allow shrinking: keep the more conservative one.
            desired = max(p_desired, d_desired)
        else:
            logger.info("Agg scaling: no scaling needed")
            return None

        desired = max(desired, self._config.min_endpoint)
        if self._config.enable_throughput_scaling:
            desired = max(desired, self._throughput_lower_bound_d)
        desired = self._apply_single_budget(desired, "decode")
        logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
        return ScalingDecision(num_decode=desired)

    # ------------------------------------------------------------------
    # Per-engine latency estimation
    # ------------------------------------------------------------------

    def _prefill_load_decision(
        self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
    ) -> Optional[int]:
        """Estimate the next TTFT per prefill engine and derive a replica count.

        Returns ``None`` when the regression lacks data, there are no
        workers, ``max_num_batched_tokens`` is unknown, or no change is
        warranted.
        """
        model = self._prefill_regression
        if not model.has_sufficient_data():
            logger.info(
                f"TTFT regression: insufficient data "
                f"({model.num_observations}/{model.min_observations})"
            )
            return None
        if num_workers == 0:
            return None

        prefill_caps = self._capabilities.prefill
        max_tokens = prefill_caps.max_num_batched_tokens if prefill_caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning(
                "max_num_batched_tokens not available, skipping prefill load scaling"
            )
            return None

        ttft_ms_per_engine: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            queued = fpm.queued_requests.sum_prefill_tokens
            seconds = model.estimate_next_ttft(
                queued_prefill_tokens=queued,
                max_num_batched_tokens=max_tokens,
            )
            if seconds is None:
                continue
            est_ms = seconds * 1000  # the estimate is scaled by 1000 before the SLA check
            ttft_ms_per_engine.append(est_ms)
            logger.info(
                f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
                f"(queued={queued}, "
                f"avg_isl={model.avg_isl:.1f})"
            )

        return self._scale_decision(
            ttft_ms_per_engine, self._config.ttft, num_workers, "prefill TTFT"
        )

    def _decode_load_decision(
        self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
    ) -> Optional[int]:
        """Estimate the next ITL per decode engine and derive a replica count.

        Returns ``None`` when the regression lacks data, there are no
        workers, or no change is warranted.
        """
        model = self._decode_regression
        if not model.has_sufficient_data():
            logger.info(
                f"ITL regression: insufficient data "
                f"({model.num_observations}/{model.min_observations})"
            )
            return None
        if num_workers == 0:
            return None

        itl_ms_per_engine: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            sched_kv = fpm.scheduled_requests.sum_decode_kv_tokens
            queued_kv = fpm.queued_requests.sum_decode_kv_tokens
            seconds = model.estimate_next_itl(
                scheduled_decode_kv=sched_kv,
                queued_decode_kv=queued_kv,
            )
            if seconds is None:
                continue
            est_ms = seconds * 1000
            itl_ms_per_engine.append(est_ms)
            logger.info(
                f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
                f"(sched_kv={sched_kv}, "
                f"queued_kv={queued_kv})"
            )

        return self._scale_decision(
            itl_ms_per_engine, self._config.itl, num_workers, "decode ITL"
        )

    def _agg_prefill_scaling(
        self,
        fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
        num_workers: int,
        max_tokens: int,
    ) -> Optional[int]:
        """TTFT-driven replica suggestion for aggregated engines."""
        raw = (
            self._agg_regression.estimate_next_ttft(
                queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
                max_num_batched_tokens=max_tokens,
                current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        ttft_ms = [seconds * 1000 for seconds in raw if seconds is not None]
        return self._scale_decision(ttft_ms, self._config.ttft, num_workers, "agg TTFT")

    def _agg_decode_scaling(
        self,
        fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
        num_workers: int,
    ) -> Optional[int]:
        """ITL-driven replica suggestion for aggregated engines."""
        raw = (
            self._agg_regression.estimate_next_itl(
                scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
                queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        itl_ms = [seconds * 1000 for seconds in raw if seconds is not None]
        return self._scale_decision(itl_ms, self._config.itl, num_workers, "agg ITL")

    def _scale_decision(
        self, estimates: list[float], sla: float, num_workers: int, label: str
    ) -> Optional[int]:
        """Turn per-engine latency estimates into a desired replica count.

        * every estimate above ``sla``  -> ``num_workers + 1``
        * more than one worker and every estimate below
          ``sla * load_scaling_down_sensitivity / 100`` -> one fewer
          worker, floored at ``min_endpoint``
        * otherwise (or with no estimates) -> ``None``
        """
        if not estimates:
            return None

        sensitivity = self._config.load_scaling_down_sensitivity / 100.0
        logger.info(
            f"Load-based {label}: workers={num_workers}, sla={sla:.1f}ms, "
            f"estimates={[f'{t:.1f}' for t in estimates]}"
        )

        if all(value > sla for value in estimates):
            logger.info(
                f"Load-based {label}: ALL above SLA, scaling up to {num_workers + 1}"
            )
            return num_workers + 1

        if num_workers <= 1:
            return None
        threshold = sla * sensitivity
        if not all(value < threshold for value in estimates):
            return None

        desired = max(num_workers - 1, self._config.min_endpoint)
        logger.info(
            f"Load-based {label}: ALL below threshold ({threshold:.1f}ms), -> {desired}"
        )
        return desired
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Optional
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.core.base import BasePlanner
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class PrefillPlanner(BasePlanner):
    """Planner for the prefill component (TTFT-driven)."""

    component_type = SubComponentType.PREFILL

    def load_plan_adjustment(self) -> Optional[int]:
        """Load-based scaling decision for prefill using FPM data.

        For every engine the TTFT regression estimates the next TTFT from
        the queued prefill tokens and ``max_num_batched_tokens``.  The
        per-engine estimates (in ms) are then handed to
        ``_load_based_scaling_decision_from_estimates`` together with the
        TTFT SLA.

        Returns ``None`` when the regression lacks data, no FPM stats are
        available, there are no prefill workers, or
        ``max_num_batched_tokens`` is unknown.
        """
        regression = self.ttft_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"TTFT regression: insufficient data ({regression.num_observations}"
                f"/{regression.min_observations}), skipping load-based scaling"
            )
            return None

        fpm_stats = self._get_fpm_stats()
        if not fpm_stats:
            return None

        num_workers = self.shared_state.num_p_workers
        if num_workers == 0:
            return None

        batch_limit = getattr(self.prefill_worker_info, "max_num_batched_tokens", None)
        if not batch_limit or batch_limit <= 0:
            logger.warning(
                "max_num_batched_tokens not available from WorkerInfo, "
                "skipping prefill load-based scaling"
            )
            return None

        ttft_ms_per_engine: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            queued_prefill = fpm.queued_requests.sum_prefill_tokens
            seconds = regression.estimate_next_ttft(
                queued_prefill_tokens=queued_prefill,
                max_num_batched_tokens=batch_limit,
            )
            if seconds is not None:
                est_ms = seconds * 1000
                ttft_ms_per_engine.append(est_ms)
                logger.info(
                    f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
                    f"(queued_prefill={queued_prefill}, avg_isl={regression.avg_isl:.1f})"
                )

        return self._load_based_scaling_decision_from_estimates(
            estimates=ttft_ms_per_engine,
            sla=self.config.ttft,
            num_workers=num_workers,
            label="prefill TTFT",
        )

    def _compute_replica_requirements(
        self, next_num_req: float, next_isl: float, next_osl: float
    ) -> Optional[int]:
        """Prefill replicas needed to serve the predicted request rate.

        ``next_osl`` is accepted for signature parity but not used here.
        Returns ``None`` while the perf model reports a non-positive
        per-engine rate.
        """
        cfg = self.config
        demand_rps = next_num_req / cfg.throughput_adjustment_interval
        engine_rps, actual_ttft_ms = self.ttft_regression.find_best_engine_prefill_rps(
            ttft_sla=cfg.ttft, isl=next_isl
        )
        if engine_rps <= 0:
            logger.warning("Prefill perf model not ready, skipping throughput scaling")
            return None

        if actual_ttft_ms > cfg.ttft:
            # Still scale, using the best rate the engine can achieve.
            logger.warning(
                f"Prefill TTFT SLA not met: {actual_ttft_ms:.1f}ms > "
                f"{cfg.ttft:.1f}ms, scaling with best achievable rate"
            )

        next_num_p = max(math.ceil(demand_rps / engine_rps), cfg.min_endpoint)
        logger.info(
            f"Prefill: {demand_rps:.2f}(demand rps) / "
            f"{engine_rps:.2f}(engine rps) = {next_num_p}(num_p), "
            f"est_ttft={actual_ttft_ms:.1f}ms"
        )
        return next_num_p

    def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
        """Publish the predicted prefill replica count when metrics are enabled."""
        if self.prometheus_port == 0 or self.prometheus_metrics is None:
            return
        self.prometheus_metrics.predicted_num_p.set(desired_replicas)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pure discrete-event state machine for planner scaling decisions.
``PlannerStateMachine`` receives events (``ScheduledTick`` + ``TickInput``),
updates internal state (regression models, load predictors, worker inventory),
and returns effects (``PlannerEffects``: optional scaling decision + next tick).
This module contains **zero I/O** -- no runtime, connector, subscriber, asyncio,
or Prometheus dependencies. All external interaction is done by the adapter
layer (``NativePlannerBase`` and its subclasses) which feeds data in and
applies decisions out.
Load-based scaling logic lives in ``load_scaling.py``.
Throughput-based scaling logic lives in ``throughput_scaling.py``.
"""
from __future__ import annotations
import logging
import math
from typing import TYPE_CHECKING, Optional
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.load.predictors import LOAD_PREDICTORS
from dynamo.planner.core.load_scaling import LoadScalingMixin
from dynamo.planner.core.perf_model import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.core.throughput_scaling import ThroughputScalingMixin
from dynamo.planner.core.types import (
FpmObservations,
PlannerEffects,
ScheduledTick,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
    """Discrete-event state machine for all planner modes.

    Owns regression models, load predictors, throughput lower bounds,
    and all scaling decision logic.  Receives events, returns effects.
    Has no runtime dependencies.
    """

    def __init__(
        self,
        config: PlannerConfig,
        capabilities: Optional[WorkerCapabilities] = None,
    ) -> None:
        """Build the regression models and predictors required by ``config.mode``.

        Args:
            config: Planner configuration; ``mode`` selects which regression
                models are created (agg: one agg model; otherwise a prefill
                and/or decode model).
            capabilities: Static engine capabilities.  An empty
                ``WorkerCapabilities`` is used when omitted.
        """
        self._config = config
        self._capabilities = capabilities or WorkerCapabilities()
        self._is_agg = config.mode == "agg"
        self._has_prefill = config.mode in ("disagg", "prefill")
        self._has_decode = config.mode in ("disagg", "decode", "agg")

        if self._is_agg:
            self._agg_regression = AggRegressionModel(
                max_num_fpm_samples=config.max_num_fpm_samples,
                min_observations=config.load_min_observations,
                bucket_count=config.fpm_sample_bucket_size,
            )
        else:
            if self._has_prefill:
                self._prefill_regression = PrefillRegressionModel(
                    max_num_fpm_samples=config.max_num_fpm_samples,
                    min_observations=config.load_min_observations,
                    bucket_count=config.fpm_sample_bucket_size,
                )
            if self._has_decode:
                self._decode_regression = DecodeRegressionModel(
                    max_num_fpm_samples=config.max_num_fpm_samples,
                    min_observations=config.load_min_observations,
                    bucket_count=config.fpm_sample_bucket_size,
                )

        # One predictor instance per traffic dimension, all of the same class.
        predictor_cls = LOAD_PREDICTORS[config.load_predictor]
        self._num_req_predictor = predictor_cls(config)
        self._isl_predictor = predictor_cls(config)
        self._osl_predictor = predictor_cls(config)

        # Ready worker counts and, when a scaling operation is pending,
        # the counts the adapter expects to reach.
        self._num_p_workers: int = 0
        self._num_d_workers: int = 0
        self._expected_num_p: Optional[int] = None
        self._expected_num_d: Optional[int] = None

        # Floors written by throughput scaling and read by load scaling
        # when both are enabled.
        self._throughput_lower_bound_p: int = 1
        self._throughput_lower_bound_d: int = 1

        # Absolute due times of the two cadences; inf = not scheduled.
        self._next_load_s: float = float("inf")
        self._next_throughput_s: float = float("inf")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def initial_tick(self, start_s: float) -> ScheduledTick:
        """Schedule the first tick relative to ``start_s``.

        The load cadence is always armed; the throughput cadence only when
        throughput scaling is enabled.
        """
        self._next_load_s = start_s + self._config.load_adjustment_interval
        if self._config.enable_throughput_scaling:
            self._next_throughput_s = (
                start_s + self._config.throughput_adjustment_interval
            )
        return self._next_scheduled_tick()

    def load_benchmark_fpms(
        self,
        prefill_fpms: Optional[list[ForwardPassMetrics]] = None,
        decode_fpms: Optional[list[ForwardPassMetrics]] = None,
        agg_fpms: Optional[list[ForwardPassMetrics]] = None,
    ) -> None:
        """Bootstrap the regression models from benchmark FPMs.

        Each list is applied only if it is non-empty and the matching
        regression model exists for the configured mode; other lists are
        ignored.
        """
        if agg_fpms and self._is_agg:
            self._agg_regression.load_benchmark_fpms(agg_fpms)
            logger.info(f"Bootstrapped agg regression with {len(agg_fpms)} FPMs")
        if prefill_fpms and self._has_prefill and not self._is_agg:
            self._prefill_regression.load_benchmark_fpms(prefill_fpms)
            logger.info(
                f"Bootstrapped prefill regression with {len(prefill_fpms)} FPMs"
            )
        if decode_fpms and self._has_decode and not self._is_agg:
            self._decode_regression.load_benchmark_fpms(decode_fpms)
            logger.info(f"Bootstrapped decode regression with {len(decode_fpms)} FPMs")

    def warm_load_predictors(self, observations: list[TrafficObservation]) -> None:
        """Feed historical traffic observations into the three load predictors."""
        for obs in observations:
            self._num_req_predictor.add_data_point(obs.num_req)
            self._isl_predictor.add_data_point(obs.isl)
            self._osl_predictor.add_data_point(obs.osl)
        logger.info(f"Warmed load predictors with {len(observations)} intervals")
        # Only some predictor classes expose reset_idle_skip; call it when present.
        for p in (self._num_req_predictor, self._isl_predictor, self._osl_predictor):
            if hasattr(p, "reset_idle_skip"):
                p.reset_idle_skip()

    def on_tick(self, tick: ScheduledTick, tick_input: TickInput) -> PlannerEffects:
        """Process one tick and return the resulting effects.

        Args:
            tick: The tick being executed (which cadences to run).
            tick_input: Data collected by the adapter for this tick.

        Returns:
            ``PlannerEffects`` with an optional scaling decision and the
            next tick.  If both cadences yield a decision in the same tick,
            the load decision wins.
        """
        effects = PlannerEffects()

        if tick_input.worker_counts is not None:
            self._update_inventory(tick_input.worker_counts)

        if tick.run_load_scaling:
            if tick_input.fpm_observations is not None:
                self._observe_fpm(tick_input.fpm_observations)
                load_decision = self._advance_load(tick_input.fpm_observations)
                if load_decision is not None:
                    effects.scale_to = load_decision
            # Re-arm the cadence even when no FPM data arrived.
            self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval

        if tick.run_throughput_scaling:
            if tick_input.traffic is not None:
                self._observe_traffic(tick_input.traffic)
                throughput_decision = self._advance_throughput(tick_input.traffic)
                if throughput_decision is not None:
                    # Do not override a load decision made on this tick.
                    if effects.scale_to is None:
                        effects.scale_to = throughput_decision
            self._next_throughput_s = (
                tick_input.now_s + self._config.throughput_adjustment_interval
            )

        effects.next_tick = self._next_scheduled_tick()
        return effects

    # ------------------------------------------------------------------
    # Tick scheduling
    # ------------------------------------------------------------------

    # Cadences due within this many seconds of the earliest one share a tick.
    _MERGE_TOLERANCE_S = 0.5

    def _next_scheduled_tick(self) -> ScheduledTick:
        """Build the single next tick, merging cadences if they coincide."""
        at_s = min(self._next_load_s, self._next_throughput_s)
        is_load = self._next_load_s <= at_s + self._MERGE_TOLERANCE_S
        is_throughput = self._next_throughput_s <= at_s + self._MERGE_TOLERANCE_S
        return ScheduledTick(
            at_s=at_s,
            run_load_scaling=is_load,
            run_throughput_scaling=is_throughput,
            need_worker_states=True,
            need_worker_fpm=is_load,
            need_traffic_metrics=is_throughput,
            traffic_metrics_duration_s=(
                self._config.throughput_adjustment_interval if is_throughput else 0.0
            ),
        )

    # ------------------------------------------------------------------
    # Inventory
    # ------------------------------------------------------------------

    def _update_inventory(self, counts: WorkerCounts) -> None:
        """Record the adapter-reported worker counts.

        Ready counts are only overwritten when provided; expected counts
        are always overwritten, so ``None`` clears a pending target.
        """
        if counts.ready_num_prefill is not None:
            self._num_p_workers = counts.ready_num_prefill
        if counts.ready_num_decode is not None:
            self._num_d_workers = counts.ready_num_decode
        self._expected_num_p = counts.expected_num_prefill
        self._expected_num_d = counts.expected_num_decode

    def _scaling_in_progress(self, component: str) -> bool:
        """Return True when the expected count differs from the ready count.

        ``component == "prefill"`` checks prefill; any other value checks
        decode.
        """
        if component == "prefill":
            return (
                self._expected_num_p is not None
                and self._expected_num_p != self._num_p_workers
            )
        return (
            self._expected_num_d is not None
            and self._expected_num_d != self._num_d_workers
        )

    # ------------------------------------------------------------------
    # FPM / traffic observation
    # ------------------------------------------------------------------

    def _observe_fpm(self, obs: FpmObservations) -> None:
        """Feed the observed FPMs into the regression model(s) for this mode."""
        if self._is_agg:
            # Aggregated engines are reported under ``decode``.
            if obs.decode:
                for fpm in obs.decode.values():
                    self._agg_regression.add_observation(fpm)
                logger.info(f"FPM load stats: {len(obs.decode)} agg engines observed")
            return
        if obs.prefill and self._has_prefill:
            for fpm in obs.prefill.values():
                self._prefill_regression.add_observation(fpm)
            logger.info(f"FPM load stats: {len(obs.prefill)} prefill engines observed")
        if obs.decode and self._has_decode:
            for fpm in obs.decode.values():
                self._decode_regression.add_observation(fpm)
            logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed")

    def _observe_traffic(self, traffic: TrafficObservation) -> None:
        """Append one traffic observation to each load predictor."""
        self._num_req_predictor.add_data_point(traffic.num_req)
        self._isl_predictor.add_data_point(traffic.isl)
        self._osl_predictor.add_data_point(traffic.osl)

    # ------------------------------------------------------------------
    # Budget
    # ------------------------------------------------------------------

    def _apply_single_budget(self, desired: int, component: str) -> int:
        """Clamp ``desired`` replicas of one component to the GPU budget.

        Returns ``desired`` unchanged when the per-engine GPU count for the
        component is unknown.
        """
        caps = (
            self._capabilities.prefill
            if component == "prefill"
            else self._capabilities.decode
        )
        gpu = caps.num_gpu if caps else None
        if gpu is None:
            return desired
        return self._budget_clamp(max(desired, self._config.min_endpoint), gpu)

    def _apply_global_budget(self, num_p: int, num_d: int) -> tuple[int, int]:
        """Fit a (prefill, decode) request into ``max_gpu_budget``.

        Args:
            num_p: Requested prefill replicas.
            num_d: Requested decode replicas.

        Returns:
            The possibly reduced ``(num_p, num_d)``.  The request is
            returned unchanged when the budget is negative (treated as
            unlimited), a per-engine GPU count is unknown, or the request
            already fits.  ``(0, 0)`` is returned when the budget cannot
            hold ``min_endpoint`` replicas of each component.
        """
        budget = self._config.max_gpu_budget
        p_gpu = (
            self._capabilities.prefill.num_gpu if self._capabilities.prefill else None
        )
        d_gpu = self._capabilities.decode.num_gpu if self._capabilities.decode else None
        if budget < 0 or p_gpu is None or d_gpu is None:
            return num_p, num_d

        total = num_p * p_gpu + num_d * d_gpu
        if total <= budget:
            return num_p, num_d

        min_ep = self._config.min_endpoint
        min_req = min_ep * p_gpu + min_ep * d_gpu
        if budget < min_req:
            logger.warning(
                f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
            )
            return 0, 0

        # Shrink prefill proportionally, leaving room for min_endpoint decode.
        scale = budget / total
        max_p = math.floor((budget - min_ep * d_gpu) / p_gpu)
        num_p = max(min_ep, min(max_p, math.floor(num_p * scale)))

        # Decode takes what is left, but never more than was requested:
        # flooring prefill can free more GPUs than decode asked for, and a
        # budget clamp must not turn into a decode scale-up.
        remaining = budget - num_p * p_gpu
        num_d = max(min_ep, min(num_d, math.floor(remaining / d_gpu)))

        logger.warning(f"GPUs ({total}) > budget ({budget}), -> {num_p}P + {num_d}D")
        return num_p, num_d

    def _budget_clamp(self, desired: int, engine_gpu: int) -> int:
        """Clamp ``desired`` replicas of ``engine_gpu`` GPUs each to the budget.

        Returns ``desired`` when the budget is negative or already
        satisfied, ``0`` when even ``min_endpoint`` replicas do not fit,
        and otherwise the largest count that fits (at least
        ``min_endpoint``).
        """
        budget = self._config.max_gpu_budget
        if budget < 0:
            return desired
        total = desired * engine_gpu
        if total <= budget:
            return desired
        min_req = self._config.min_endpoint * engine_gpu
        if budget < min_req:
            logger.warning(
                f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
            )
            return 0
        result = max(self._config.min_endpoint, math.floor(budget / engine_gpu))
        logger.warning(f"GPUs ({total}) > budget ({budget}), -> {result} replicas")
        return result

    # ------------------------------------------------------------------
    # FPM / worker count reconciliation
    # ------------------------------------------------------------------

    @staticmethod
    def _reconcile_fpm_worker_count(
        fpm_stats: dict[tuple[str, int], ForwardPassMetrics], dgd_count: int, label: str
    ) -> bool:
        """Check that the FPM keys cover exactly the known workers.

        Args:
            fpm_stats: FPMs keyed by ``(worker_id, dp_rank)``.
            dgd_count: Worker count the planner currently believes in.
            label: Component name used in warning messages.

        Returns:
            True when the number of distinct worker ids equals
            ``dgd_count``, every worker reports the same number of DP
            ranks, and no ``(worker, rank)`` entry is missing.
        """
        workers_to_dp: dict[str, set[int]] = {}
        for wid, dp in fpm_stats:
            workers_to_dp.setdefault(wid, set()).add(dp)

        if len(workers_to_dp) != dgd_count:
            logger.warning(
                f"Worker count mismatch: DGD={dgd_count}, FPM={len(workers_to_dp)} for {label}"
            )
            return False

        dp_sizes = {len(dps) for dps in workers_to_dp.values()}
        if len(dp_sizes) > 1:
            logger.warning(f"Inconsistent DP ranks for {label}: {dict(workers_to_dp)}")
            return False

        # With no workers at all, fall back to a DP size of 1.
        dp_size = dp_sizes.pop() if dp_sizes else 1
        if len(fpm_stats) != dgd_count * dp_size:
            logger.warning(
                f"Incomplete FPM coverage for {label}: expected {dgd_count}x{dp_size}, got {len(fpm_stats)}"
            )
            return False
        return True

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------

    @property
    def prefill_regression(self) -> PrefillRegressionModel:
        """Prefill regression model; raises ``AttributeError`` if the mode has none."""
        if not self._has_prefill:
            raise AttributeError(f"No prefill regression in mode={self._config.mode}")
        return self._prefill_regression

    @property
    def decode_regression(self) -> DecodeRegressionModel:
        """Decode regression model; raises ``AttributeError`` in agg or prefill-only mode."""
        if not self._has_decode or self._is_agg:
            raise AttributeError(f"No decode regression in mode={self._config.mode}")
        return self._decode_regression

    @property
    def agg_regression(self) -> AggRegressionModel:
        """Agg regression model; raises ``AttributeError`` outside agg mode."""
        if not self._is_agg:
            raise AttributeError(f"No agg regression in mode={self._config.mode}")
        return self._agg_regression

    @property
    def regression(self) -> AggRegressionModel:
        """Alias for ``agg_regression``."""
        return self.agg_regression
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Throughput-based scaling logic (Prometheus traffic-driven, predictive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from __future__ import annotations
import logging
import math
from typing import Optional
from dynamo.planner.core.types import ScalingDecision, TrafficObservation
logger = logging.getLogger(__name__)
class ThroughputScalingMixin:
    """Predictive replica-count decisions derived from observed traffic.

    The host class is expected to supply ``_config``, ``_capabilities``,
    the three load predictors, the regression models, the throughput
    lower-bound attributes and the budget helpers.
    """

    def _advance_throughput(
        self, traffic: TrafficObservation
    ) -> Optional[ScalingDecision]:
        """Run one throughput-scaling step for the configured mode.

        Returns ``None`` when throughput scaling is disabled, prediction
        fails, the observation window is non-positive, or the mode handler
        produces no decision.
        """
        cfg = self._config
        if not cfg.enable_throughput_scaling:
            return None

        forecast = self._predict_load()
        if any(value is None for value in forecast):
            return None
        num_req, isl, osl = forecast

        if traffic.duration_s <= 0:
            logger.warning("Traffic observation has non-positive duration, skipping")
            return None

        demand_rps = num_req / traffic.duration_s
        mode = cfg.mode
        if mode == "agg":
            return self._throughput_agg(demand_rps, isl, osl)
        if mode == "disagg":
            return self._throughput_disagg(demand_rps, isl, osl)
        return self._throughput_single(demand_rps, isl, osl, mode)

    def _predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
        """Ask the predictors for the next ``(num_req, isl, osl)``.

        Any exception raised while predicting or while formatting the log
        line is logged and reported as ``(None, None, None)``.
        """
        try:
            num_req = self._num_req_predictor.predict_next()
            isl = self._isl_predictor.predict_next()
            osl = self._osl_predictor.predict_next()
            logger.info(
                f"Predicted load: num_req={num_req:.2f}, isl={isl:.2f}, osl={osl:.2f}"
            )
        except Exception as e:
            logger.error(f"Failed to predict load: {e}")
            return None, None, None
        else:
            return num_req, isl, osl

    def _throughput_single(
        self, demand_rps: float, isl: float, osl: float, component: str
    ) -> Optional[ScalingDecision]:
        """Throughput step for a planner that manages a single component.

        With load scaling enabled the result is only stored as a lower
        bound and no decision is returned.
        """
        is_prefill = component == "prefill"
        if is_prefill:
            desired = self._compute_prefill_replicas(demand_rps, isl, osl)
        else:
            desired = self._compute_decode_replicas(demand_rps, isl, osl)
        if desired is None:
            return None

        if self._config.enable_load_scaling:
            if is_prefill:
                self._throughput_lower_bound_p = desired
            else:
                self._throughput_lower_bound_d = desired
            logger.info(f"Throughput lower bound set to {desired} for {component}")
            return None

        desired = self._apply_single_budget(desired, component)
        if is_prefill:
            return ScalingDecision(num_prefill=desired)
        return ScalingDecision(num_decode=desired)

    def _throughput_disagg(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[ScalingDecision]:
        """Throughput step for disaggregated mode (both components)."""
        num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
        num_d = self._compute_decode_replicas(demand_rps, isl, osl)
        if None in (num_p, num_d):
            return None

        if self._config.enable_load_scaling:
            # Load scaling makes the final call; only record the floors.
            self._throughput_lower_bound_p = num_p
            self._throughput_lower_bound_d = num_d
            logger.info(f"Throughput lower bounds set: prefill={num_p}, decode={num_d}")
            return None

        num_p, num_d = self._apply_global_budget(num_p, num_d)
        return ScalingDecision(num_prefill=num_p, num_decode=num_d)

    def _throughput_agg(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[ScalingDecision]:
        """Throughput step for aggregated mode."""
        cfg = self._config
        decode_caps = self._capabilities.decode
        max_tokens = decode_caps.max_num_batched_tokens if decode_caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning(
                "max_num_batched_tokens not available, skipping agg throughput"
            )
            return None

        engine_rps, actual_ttft, actual_itl = (
            self._agg_regression.find_best_engine_agg_rps(
                isl=isl,
                osl=osl,
                max_num_batched_tokens=max_tokens,
                ttft_sla=cfg.ttft,
                itl_sla=cfg.itl,
            )
        )
        if engine_rps <= 0:
            logger.warning("Agg perf model not ready, skipping throughput scaling")
            return None
        if actual_ttft > cfg.ttft or actual_itl > cfg.itl:
            logger.warning(
                f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms, ITL={actual_itl:.1f}ms"
            )

        desired = max(math.ceil(demand_rps / engine_rps), cfg.min_endpoint)
        logger.info(
            f"Agg: {demand_rps:.2f} rps / {engine_rps:.2f} engine_rps = {desired} replicas"
        )

        if cfg.enable_load_scaling:
            self._throughput_lower_bound_d = desired
            logger.info(f"Agg throughput lower bound set to {desired}")
            return None

        desired = self._apply_single_budget(desired, "decode")
        return ScalingDecision(num_decode=desired)

    def _compute_prefill_replicas(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[int]:
        """Prefill replicas needed for ``demand_rps`` (``osl`` is unused).

        Returns ``None`` while the perf model reports a non-positive
        per-engine rate.
        """
        ttft_sla = self._config.ttft
        engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
            ttft_sla=ttft_sla, isl=isl
        )
        if engine_rps <= 0:
            logger.warning("Prefill perf model not ready, skipping throughput scaling")
            return None
        if ttft_ms > ttft_sla:
            logger.warning(
                f"Prefill TTFT SLA not met: {ttft_ms:.1f}ms > {ttft_sla:.1f}ms"
            )

        replicas = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
        logger.info(
            f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {replicas}, est_ttft={ttft_ms:.1f}ms"
        )
        return replicas

    def _compute_decode_replicas(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[int]:
        """Decode replicas needed for ``demand_rps``.

        The context length handed to the perf model is ``isl + osl / 2``.
        Returns ``None`` while the perf model reports a non-positive
        per-engine rate.
        """
        itl_sla = self._config.itl
        engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps(
            itl=itl_sla,
            context_length=isl + osl / 2,
            osl=osl,
        )
        if engine_rps <= 0:
            logger.warning("Decode perf model not ready, skipping throughput scaling")
            return None
        if itl_ms > itl_sla:
            logger.warning(
                f"Decode ITL SLA not met: {itl_ms:.1f}ms > {itl_sla:.1f}ms"
            )

        replicas = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
        logger.info(
            f"Decode: {demand_rps:.2f} rps / {engine_rps:.2f} = {replicas}, est_itl={itl_ms:.1f}ms"
        )
        return replicas
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Explicit-input types for the planner core.
These types form the boundary between the planner core (pure decision logic)
and any adapter (native runtime, replay harness, tests). The core receives
``TickInput`` and returns ``PlannerEffects``; the adapter fills the input
based on the previous tick's ``ScheduledTick`` requirements.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
@dataclass
class ScheduledTick:
    """Request from the core for its next invocation.

    States when the core wants to be called, which decisions it will run
    at that time, and which inputs the adapter must gather beforehand.

    Times are absolute seconds: wall clock for the native adapter,
    simulated clock for replay.
    """

    # Absolute time at which the tick is due.
    at_s: float

    # Decisions the core will make on this tick.
    run_load_scaling: bool = False
    run_throughput_scaling: bool = False

    # Data the adapter should collect before calling on_tick.
    need_traffic_metrics: bool = False
    traffic_metrics_duration_s: float = 0.0
    need_worker_states: bool = False
    need_worker_fpm: bool = False
@dataclass
class TrafficObservation:
    """Traffic metrics aggregated over one observation window."""

    # Length of the observation window in seconds.
    duration_s: float
    # Request count, input sequence length and output sequence length
    # reported for the window.
    num_req: float
    isl: float
    osl: float
@dataclass
class WorkerCounts:
    """Worker inventory snapshot supplied by the adapter.

    Every field defaults to ``None``.
    """

    # Workers currently ready, per component.
    ready_num_prefill: Optional[int] = None
    ready_num_decode: Optional[int] = None
    # Workers expected per component.
    expected_num_prefill: Optional[int] = None
    expected_num_decode: Optional[int] = None
@dataclass
class FpmObservations:
    """ForwardPassMetrics per engine, keyed by ``(worker_id, dp_rank)``.

    One mapping per component; both default to ``None``.
    """

    prefill: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
    decode: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
@dataclass
class TickInput:
    """Data handed to the core by the adapter on every tick.

    Which optional fields are populated follows the requirements declared
    by the previous ``ScheduledTick``.
    """

    # Current time in absolute seconds.
    now_s: float
    traffic: Optional[TrafficObservation] = None
    worker_counts: Optional[WorkerCounts] = None
    fpm_observations: Optional[FpmObservations] = None
@dataclass
class ScalingDecision:
    """Target replica counts produced by the core.

    A field left as ``None`` means the core expresses no opinion on that
    component (for example, a prefill-only planner leaves decode unset).
    """

    num_prefill: Optional[int] = None
    num_decode: Optional[int] = None
@dataclass
class PlannerEffects:
    """Result returned by the core after it has processed a tick."""

    # Scaling to apply, if any.
    scale_to: Optional[ScalingDecision] = None
    # When and how the core wants to be called next.
    next_tick: Optional[ScheduledTick] = None
@dataclass
class EngineCapabilities:
    """Static capabilities of one engine stage (prefill or decode).

    Every field is optional and defaults to ``None``.
    """

    num_gpu: Optional[int] = None
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: Optional[int] = None
    context_length: Optional[int] = None
@dataclass
class WorkerCapabilities:
    """Static per-engine capabilities discovered at startup from MDC.

    Supplied once when the planner core is constructed.  In native mode
    the values come from ``WorkerInfo`` (resolved via MDC / DGD); in
    replay they come from the simulated engine args.

    In agg mode only ``decode`` is populated, since there is a single
    engine type.
    """

    prefill: Optional[EngineCapabilities] = None
    decode: Optional[EngineCapabilities] = None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import Mock, patch
"""Regression model unit tests.
These test the perf_model classes directly (PrefillRegressionModel,
DecodeRegressionModel, AggRegressionModel) without any planner adapter.
FPM-driven scaling integration tests live in test_state_machine.py.
"""
import pytest
......@@ -15,18 +20,12 @@ from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
encode,
)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.perf_model import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.worker_info import WorkerInfo
pytestmark = [
pytest.mark.gpu_0,
......@@ -149,7 +148,6 @@ class TestPrefillRegressionModel:
ttft_sla=2000.0, isl=1000.0
)
assert rps > 0
# wall_time ~1.002s for 1000 tokens -> rps ~ 1/1.002 ~ 0.998
assert 0.5 < rps < 2.0
assert actual_ttft_ms > 0
assert 1000 < actual_ttft_ms < 2000
......@@ -190,7 +188,6 @@ class TestPrefillRegressionModel:
class TestBucketedRetirement:
def test_total_capped_at_max(self):
"""Total observations never exceed max_num_fpm_samples."""
model = PrefillRegressionModel(
max_num_fpm_samples=10, min_observations=3, bucket_count=4
)
......@@ -204,12 +201,9 @@ class TestBucketedRetirement:
assert model.num_observations == 10
def test_most_populated_bucket_loses_oldest(self):
"""When evicting, the oldest entry from the most-populated bucket is removed."""
model = PrefillRegressionModel(
max_num_fpm_samples=6, min_observations=1, bucket_count=4
)
# 3 observations at low tokens (bucket 0 area)
for i in range(3):
fpm = _make_fpm(
sum_prefill_tokens=10 + i,
......@@ -217,8 +211,6 @@ class TestBucketedRetirement:
wall_time=0.001 * (10 + i),
)
model.add_observation(fpm)
# 3 observations at high tokens (different bucket)
for i in range(3):
fpm = _make_fpm(
sum_prefill_tokens=1000 + i * 100,
......@@ -226,51 +218,31 @@ class TestBucketedRetirement:
wall_time=0.001 * (1000 + i * 100),
)
model.add_observation(fpm)
assert model.num_observations == 6
# One more at low tokens; total would exceed 6 so most-populated
# bucket loses its oldest entry.
fpm = _make_fpm(
sum_prefill_tokens=15,
num_prefill_requests=1,
wall_time=0.015,
)
fpm = _make_fpm(sum_prefill_tokens=15, num_prefill_requests=1, wall_time=0.015)
model.add_observation(fpm)
assert model.num_observations == 6
def test_uniform_distribution_preserved(self):
    """Bucketed eviction keeps observations across operating points.

    Feeds more observations than ``max_num_fpm_samples`` at one operating
    point, then one at a different point, and checks that the stored
    observation count stays at the cap both times.
    """
    model = DecodeRegressionModel(
        max_num_fpm_samples=10, min_observations=3, bucket_count=16
    )

    # Many observations at a single operating point (15 > cap of 10).
    # Each keyword is passed exactly once: repeating a keyword argument in
    # a call is a SyntaxError.
    for _ in range(15):
        fpm = _make_fpm(
            num_decode_requests=32, sum_decode_kv_tokens=32000, wall_time=0.01
        )
        model.add_observation(fpm)
    assert model.num_observations == 10

    # Add a different operating point; the concentrated bucket is expected
    # to lose one entry so the total remains at the cap.
    fpm = _make_fpm(
        num_decode_requests=4, sum_decode_kv_tokens=4000, wall_time=0.005
    )
    model.add_observation(fpm)
    assert model.num_observations == 10
def test_2d_bucketed_retirement(self):
"""2D models retire from the most-populated grid cell."""
model = AggRegressionModel(
max_num_fpm_samples=8, min_observations=1, bucket_count=16
)
# Fill with varied data
for p, d in [(100, 500), (200, 1000), (300, 1500), (400, 2000)]:
fpm = _make_fpm(
sum_prefill_tokens=p,
......@@ -280,8 +252,6 @@ class TestBucketedRetirement:
wall_time=0.001 * p + 0.0001 * d,
)
model.add_observation(fpm)
# Concentrate 4 more in one region
for _ in range(4):
fpm = _make_fpm(
sum_prefill_tokens=100,
......@@ -291,10 +261,7 @@ class TestBucketedRetirement:
wall_time=0.15,
)
model.add_observation(fpm)
assert model.num_observations == 8
# Overflow triggers retirement from the concentrated cell
fpm = _make_fpm(
sum_prefill_tokens=350,
num_prefill_requests=1,
......@@ -310,15 +277,8 @@ class TestBucketedRetirement:
class TestDecodeRegressionModel:
def _train_2d(self, model: DecodeRegressionModel) -> None:
"""Populate with 2D data: wall_time = f(num_decode_requests, sum_decode_kv_tokens)."""
for n_req, kv in [
(5, 1000),
(10, 2000),
(15, 3000),
(20, 4000),
(25, 5000),
]:
def _train_2d(self, model):
for n_req, kv in [(5, 1000), (10, 2000), (15, 3000), (20, 4000), (25, 5000)]:
fpm = _make_fpm(
sum_decode_kv_tokens=kv,
num_decode_requests=n_req,
......@@ -346,11 +306,9 @@ class TestDecodeRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_2d(model)
assert model.has_sufficient_data()
est = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert est is not None
assert est > 0
assert est is not None and est > 0
def test_avg_decode_length_tracking(self):
model = DecodeRegressionModel(
......@@ -365,8 +323,7 @@ class TestDecodeRegressionModel:
model.add_observation(fpm)
assert abs(model.avg_decode_length - 200.0) < 1.0
def _train_thpt_model(self, model: DecodeRegressionModel) -> None:
"""Populate with 2D data at decode-realistic wall-time scale."""
def _train_thpt_model(self, model):
for n_req, kv in [
(5, 5000),
(10, 10000),
......@@ -386,13 +343,10 @@ class TestDecodeRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_thpt_model(model)
rps, actual_itl = model.find_best_engine_decode_rps(
itl=50.0, context_length=1000.0, osl=150.0
)
assert rps > 0
assert actual_itl > 0
assert actual_itl <= 50.0
assert rps > 0 and actual_itl > 0 and actual_itl <= 50.0
def test_find_best_engine_decode_rps_zero_context(self):
model = DecodeRegressionModel(
......@@ -402,8 +356,7 @@ class TestDecodeRegressionModel:
rps, itl_ms = model.find_best_engine_decode_rps(
itl=50.0, context_length=0.0, osl=150.0
)
assert rps == 0.0
assert itl_ms == 0.0
assert rps == 0.0 and itl_ms == 0.0
def test_load_benchmark_fpms(self):
model = DecodeRegressionModel(
......@@ -418,15 +371,14 @@ class TestDecodeRegressionModel:
for n in [5, 10, 15, 20, 25]
]
model.load_benchmark_fpms(fpms)
assert model.num_observations == 5
assert model.has_sufficient_data()
assert model.num_observations == 5 and model.has_sufficient_data()
# ── AggRegressionModel tests ─────────────────────────────────────────
class TestAggRegressionModel:
def _train_agg(self, model: AggRegressionModel) -> None:
def _train_agg(self, model):
for p, d in [(100, 1000), (200, 2000), (300, 3000), (400, 4000), (500, 5000)]:
fpm = _make_fpm(
sum_prefill_tokens=p,
......@@ -458,27 +410,19 @@ class TestAggRegressionModel:
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
assert model.has_sufficient_data()
ttft = model.estimate_next_ttft(
queued_prefill_tokens=0,
max_num_batched_tokens=2048,
current_decode_kv=3000,
queued_prefill_tokens=0, max_num_batched_tokens=2048, current_decode_kv=3000
)
assert ttft is not None
assert ttft > 0
assert ttft is not None and ttft > 0
itl = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert itl is not None
assert itl > 0
assert itl is not None and itl > 0
def test_find_best_engine_agg_rps(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
thpt, actual_ttft, actual_itl = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
......@@ -486,10 +430,7 @@ class TestAggRegressionModel:
ttft_sla=500.0,
itl_sla=50.0,
)
assert isinstance(thpt, float)
assert thpt > 0
assert actual_ttft >= 0
assert actual_itl >= 0
assert thpt > 0 and actual_ttft >= 0 and actual_itl >= 0
def test_find_best_engine_agg_rps_insufficient_data(self):
model = AggRegressionModel(
......@@ -503,221 +444,3 @@ class TestAggRegressionModel:
itl_sla=50.0,
)
assert thpt == 0.0
# ── Planner integration tests (with mocked FPM subscriber) ──────────
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Replace the planner-metrics ``Gauge`` with a mock for every test.

    The patch is active while the test body runs and is removed afterwards,
    whether or not the test raised.
    """
    gauge_patcher = patch("dynamo.planner.monitoring.planner_metrics.Gauge")
    gauge_cls = gauge_patcher.start()
    try:
        gauge_cls.return_value = Mock()
        yield
    finally:
        gauge_patcher.stop()
def _build_load_config(**overrides) -> PlannerConfig:
    """Build a ``PlannerConfig`` for the load-scaling tests.

    Starts from a fixed set of default field values and lets keyword
    ``overrides`` replace or extend them. The config is created with
    ``PlannerConfig.model_construct``.
    """
    profile_dir = os.path.join(
        os.path.dirname(__file__),
        "..",
        "data",
        "profiling_results",
        "H200_TP1P_TP1D",
    )
    base_settings = {
        "throughput_adjustment_interval": 60,
        "prefill_engine_num_gpu": 1,
        "decode_engine_num_gpu": 1,
        "min_endpoint": 1,
        "max_gpu_budget": -1,
        "ttft": 500.0,
        "itl": 50.0,
        "backend": "vllm",
        "no_operation": True,
        "metric_pulling_prometheus_endpoint": "http://localhost:9090",
        "metric_reporting_prometheus_port": 0,
        "load_predictor": "constant",
        "profile_results_dir": profile_dir,
        "environment": "kubernetes",
        "namespace": "test-namespace",
        "mode": "disagg",
        "enable_load_scaling": True,
        "enable_throughput_scaling": True,
        "load_adjustment_interval": 5,
        "max_num_fpm_samples": 50,
        "fpm_sample_bucket_size": 16,
        "load_scaling_down_sensitivity": 80,
        "load_metric_samples": 10,
        "load_min_observations": 5,
    }
    # Caller-supplied values win over the defaults above.
    return PlannerConfig.model_construct(**{**base_settings, **overrides})
def _mock_fpm_subscriber(fpm_stats: dict[tuple[str, int], ForwardPassMetrics]):
    """Return a mock FPM subscriber serving ``fpm_stats`` in encoded form.

    Each value of ``fpm_stats`` is passed through ``encode``; the resulting
    mapping (same keys) is what ``get_recent_stats()`` returns on the mock.
    """
    subscriber = Mock()
    subscriber.get_recent_stats.return_value = {
        key: encode(metrics) for key, metrics in fpm_stats.items()
    }
    return subscriber
class TestPrefillFpmScaling:
    """Load-scaling decisions of ``PrefillPlanner`` driven by mocked FPM stats."""

    @staticmethod
    def _planner(config, num_workers):
        """Create a ``PrefillPlanner`` whose shared state reports ``num_workers`` prefill workers."""
        state = PlannerSharedState()
        state.num_p_workers = num_workers
        planner = PrefillPlanner(None, config, shared_state=state)
        planner.model_name = "test-model"
        planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
        return planner

    def test_scale_up_all_engines_above_sla(self):
        """All engines have high queued prefill -> estimated TTFT > SLA -> scale up."""
        # 5ms SLA (easy to exceed)
        planner = self._planner(_build_load_config(ttft=5.0), num_workers=2)

        # Train the TTFT regression before asking for a decision.
        for tokens in range(200, 1200, 100):
            planner.ttft_regression.add_observation(
                _make_fpm(
                    sum_prefill_tokens=tokens,
                    num_prefill_requests=1,
                    wall_time=0.001 * tokens,
                )
            )

        per_worker = {
            "w1": dict(
                queued_prefill_tokens=10000,
                sum_prefill_tokens=500,
                num_prefill_requests=1,
                wall_time=0.5,
            ),
            "w2": dict(
                queued_prefill_tokens=8000,
                sum_prefill_tokens=600,
                num_prefill_requests=1,
                wall_time=0.6,
            ),
        }
        stats = {
            (worker_id, 0): _make_fpm(worker_id=worker_id, **fields)
            for worker_id, fields in per_worker.items()
        }
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        # Two workers are reported; the expected decision is three.
        assert planner.load_plan_adjustment() == 3

    def test_scale_down_all_engines_below_sla(self):
        """All engines have low queued prefill -> estimated TTFT < SLA * sensitivity."""
        config = _build_load_config(ttft=500.0, load_scaling_down_sensitivity=100)
        planner = self._planner(config, num_workers=3)

        for tokens in range(100, 600, 50):
            planner.ttft_regression.add_observation(
                _make_fpm(
                    sum_prefill_tokens=tokens,
                    num_prefill_requests=1,
                    wall_time=0.001 * tokens,
                )
            )

        stats = {}
        for i in range(3):
            worker_id = f"w{i}"
            stats[(worker_id, 0)] = _make_fpm(
                worker_id=worker_id,
                queued_prefill_tokens=0,
                sum_prefill_tokens=100,
                num_prefill_requests=1,
                wall_time=0.1,
            )
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        # Three workers are reported; the expected decision is two.
        assert planner.load_plan_adjustment() == 2

    def test_cold_start_returns_none(self):
        """With only two regression observations the planner returns no decision."""
        planner = self._planner(_build_load_config(), num_workers=2)

        for tokens in (100, 200):
            planner.ttft_regression.add_observation(
                _make_fpm(sum_prefill_tokens=tokens, wall_time=0.01)
            )

        stats = {("w1", 0): _make_fpm(queued_prefill_tokens=5000, wall_time=0.5)}
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        assert planner.load_plan_adjustment() is None
class TestDecodeFpmScaling:
    """Load-scaling decisions of ``DecodePlanner`` driven by mocked FPM stats."""

    @staticmethod
    def _planner(config, num_workers):
        """Create a ``DecodePlanner`` whose shared state reports ``num_workers`` decode workers."""
        state = PlannerSharedState()
        state.num_d_workers = num_workers
        planner = DecodePlanner(None, config, shared_state=state)
        planner.model_name = "test-model"
        return planner

    def test_scale_up_all_engines_above_sla(self):
        """All engines have high decode load -> estimated ITL > SLA -> scale up."""
        # 5ms SLA
        planner = self._planner(_build_load_config(itl=5.0), num_workers=2)

        # 2D regression: vary both num_decode_requests and sum_decode_kv_tokens
        training_points = [
            (5, 1000),
            (10, 2000),
            (15, 3000),
            (20, 4000),
            (25, 5000),
        ]
        for n_req, kv in training_points:
            planner.itl_regression.add_observation(
                _make_fpm(
                    sum_decode_kv_tokens=kv,
                    num_decode_requests=n_req,
                    wall_time=0.0001 * kv + 0.0005 * n_req + 0.001,
                )
            )

        per_worker = {
            "w1": dict(
                sum_decode_kv_tokens=5000,
                queued_decode_kv_tokens=3000,
                num_decode_requests=20,
                wall_time=0.6,
            ),
            "w2": dict(
                sum_decode_kv_tokens=4500,
                queued_decode_kv_tokens=2500,
                num_decode_requests=18,
                wall_time=0.55,
            ),
        }
        stats = {
            (worker_id, 0): _make_fpm(worker_id=worker_id, **fields)
            for worker_id, fields in per_worker.items()
        }
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        # Two workers are reported; the expected decision is three.
        assert planner.load_plan_adjustment() == 3

    def test_cold_start_returns_none(self):
        """With a single regression observation the planner returns no decision."""
        planner = self._planner(_build_load_config(), num_workers=2)

        planner.itl_regression.add_observation(
            _make_fpm(sum_decode_kv_tokens=1000, num_decode_requests=5, wall_time=0.01)
        )

        stats = {
            ("w1", 0): _make_fpm(
                sum_decode_kv_tokens=5000, num_decode_requests=10, wall_time=0.5
            )
        }
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        assert planner.load_plan_adjustment() is None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import math
import os
from unittest.mock import MagicMock, Mock, patch
import pytest
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.budget import _initialize_gpu_counts
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.errors import DeploymentValidationError
# Pytest markers applied to every test in this module.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.planner,
]

# Values returned by the mocked regression models installed in
# _build_planners: the per-engine request rate for prefill and decode, and
# the ITL (ms) paired with the decode rate.
PREFILL_ENGINE_RPS = 10.0
DECODE_ENGINE_RPS = 5.0
DECODE_ACTUAL_ITL_MS = 40.0
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Swap the planner-metrics ``Gauge`` for a mock around each test.

    The patch is started before the test body and always stopped afterwards.
    """
    patcher = patch("dynamo.planner.monitoring.planner_metrics.Gauge")
    fake_gauge = patcher.start()
    try:
        fake_gauge.return_value = Mock()
        yield
    finally:
        patcher.stop()
def _build_config():
    """Build the ``PlannerConfig`` used by the throughput-scaling tests.

    Throughput scaling is enabled and load scaling is disabled. The config
    is created with ``PlannerConfig.model_construct``.
    """
    profile_dir = os.path.join(
        os.path.dirname(__file__),
        "..",
        "data",
        "profiling_results",
        "H200_TP1P_TP1D",
    )
    settings = {
        "throughput_adjustment_interval": 60,
        "prefill_engine_num_gpu": 1,
        "decode_engine_num_gpu": 1,
        "min_endpoint": 1,
        "max_gpu_budget": -1,
        "ttft": 500.0,
        "itl": 50.0,
        "backend": "vllm",
        "no_operation": True,
        "metric_pulling_prometheus_endpoint": "http://localhost:9090",
        "metric_reporting_prometheus_port": 0,
        "load_predictor": "constant",
        "load_predictor_warmup_trace": None,
        "load_predictor_log1p": False,
        "profile_results_dir": profile_dir,
        "environment": "kubernetes",
        "namespace": "test-namespace",
        "mode": "disagg",
        "enable_throughput_scaling": True,
        "enable_load_scaling": False,
    }
    return PlannerConfig.model_construct(**settings)
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(config, prometheus_client):
    """Create a prefill and a decode planner that share one ``PlannerSharedState``.

    Both planners get the given Prometheus client, a fixed model name,
    ``MagicMock`` regression models that report sufficient data and return
    the module-level engine-rate constants, and a stubbed
    ``get_workers_info``.

    Returns:
        ``(prefill_planner, decode_planner, shared_state)``.
    """
    shared_state = PlannerSharedState()
    prefill_planner = PrefillPlanner(None, config, shared_state=shared_state)
    decode_planner = DecodePlanner(None, config, shared_state=shared_state)

    async def mock_get_workers_info(require_prefill=True, require_decode=True):
        # One worker for each requested component, zero otherwise.
        num_prefill = 1 if require_prefill else 0
        num_decode = 1 if require_decode else 0
        is_stable = True
        return (num_prefill, num_decode, is_stable)

    for planner in (prefill_planner, decode_planner):
        planner.prometheus_traffic_client = prometheus_client
        planner.model_name = "test-model"
        planner.get_workers_info = mock_get_workers_info

    ttft_model = MagicMock()
    ttft_model.find_best_engine_prefill_rps.return_value = (PREFILL_ENGINE_RPS, 75.0)
    ttft_model.has_sufficient_data.return_value = True
    prefill_planner.ttft_regression = ttft_model

    itl_model = MagicMock()
    itl_model.find_best_engine_decode_rps.return_value = (
        DECODE_ENGINE_RPS,
        DECODE_ACTUAL_ITL_MS,
    )
    itl_model.has_sufficient_data.return_value = True
    decode_planner.itl_regression = itl_model

    return prefill_planner, decode_planner, shared_state
def _expected_prefill(config, prefill_planner, sample):
demand_rps = sample["num_req"] / config.throughput_adjustment_interval
engine_rps, _ = prefill_planner.ttft_regression.find_best_engine_prefill_rps(
ttft_sla=config.ttft, isl=sample["isl"]
)
expected = math.ceil(demand_rps / engine_rps)
return max(expected, config.min_endpoint)
def _expected_decode(config, decode_planner, sample):
demand_rps = sample["num_req"] / config.throughput_adjustment_interval
engine_rps, _ = decode_planner.itl_regression.find_best_engine_decode_rps(
itl=config.itl, context_length=sample["isl"] + sample["osl"] / 2
)
expected = math.ceil(demand_rps / engine_rps)
return max(expected, config.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
asyncio.run(
prefill_planner.observe_traffic_stats(require_prefill=True, require_decode=True)
)
decode_planner.update_predictors_from_metrics(shared_state.last_metrics)
next_num_p = prefill_planner.plan_adjustment()
next_num_d = decode_planner.plan_adjustment()
return next_num_p, next_num_d
def test_disagg_scale_up():
    """A jump from 10 to 5000 requests per interval raises both replica counts."""
    config = _build_config()
    # Everything except the request count is identical across the two intervals.
    steady_traffic = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [{"num_req": num_req, **steady_traffic} for num_req in (10, 5000)]

    client = _build_prometheus_client(samples)
    prefill_planner, decode_planner, shared_state = _build_planners(config, client)

    # One interval per sample, in order: low traffic first, then high.
    (low_p, low_d), (high_p, high_d) = [
        _run_interval(prefill_planner, decode_planner, shared_state) for _ in samples
    ]

    low_sample, high_sample = samples
    assert low_p == _expected_prefill(config, prefill_planner, low_sample)
    assert low_d == _expected_decode(config, decode_planner, low_sample)
    assert high_p == _expected_prefill(config, prefill_planner, high_sample)
    assert high_d == _expected_decode(config, decode_planner, high_sample)
    assert high_p > low_p
    assert high_d > low_d
def test_disagg_scale_down():
    """A drop from 5000 to 10 requests per interval lowers both replica counts."""
    config = _build_config()
    # Everything except the request count is identical across the two intervals.
    steady_traffic = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [{"num_req": num_req, **steady_traffic} for num_req in (5000, 10)]

    client = _build_prometheus_client(samples)
    prefill_planner, decode_planner, shared_state = _build_planners(config, client)

    # One interval per sample, in order: high traffic first, then low.
    (high_p, high_d), (low_p, low_d) = [
        _run_interval(prefill_planner, decode_planner, shared_state) for _ in samples
    ]

    high_sample, low_sample = samples
    assert high_p == _expected_prefill(config, prefill_planner, high_sample)
    assert high_d == _expected_decode(config, decode_planner, high_sample)
    assert low_p == _expected_prefill(config, prefill_planner, low_sample)
    assert low_d == _expected_decode(config, decode_planner, low_sample)
    assert low_p < high_p
    assert low_d < high_d
class TestInitializeGpuCounts:
    """Tests for ``_initialize_gpu_counts``.

    "Kubernetes mode" tests use a connector mock that has ``get_gpu_counts``;
    "virtual mode" tests use ``Mock(spec=[])``, which has no attributes.
    """

    @staticmethod
    def _make_config(**overrides):
        """Return a config whose GPU counts are ``None`` unless overridden."""
        fields = {
            "prefill_engine_num_gpu": None,
            "decode_engine_num_gpu": None,
            **overrides,
        }
        return PlannerConfig.model_construct(**fields)

    @staticmethod
    def _dgd_connector(**get_gpu_counts_spec):
        """Return a connector mock whose ``get_gpu_counts`` is ``Mock(**get_gpu_counts_spec)``."""
        connector = Mock()
        connector.get_gpu_counts = Mock(**get_gpu_counts_spec)
        return connector

    def test_kubernetes_mode_reads_from_dgd(self):
        """GPU counts returned by the connector are written into the config."""
        config = self._make_config()
        connector = self._dgd_connector(return_value=(2, 4))

        _initialize_gpu_counts(
            config, connector, require_prefill=True, require_decode=True
        )

        assert (config.prefill_engine_num_gpu, config.decode_engine_num_gpu) == (2, 4)
        connector.get_gpu_counts.assert_called_once_with(
            require_prefill=True, require_decode=True
        )

    def test_kubernetes_mode_prefill_only(self):
        """Prefill-only initialization forwards the flags and stores both counts."""
        config = self._make_config()
        connector = self._dgd_connector(return_value=(2, 0))

        _initialize_gpu_counts(
            config, connector, require_prefill=True, require_decode=False
        )

        assert (config.prefill_engine_num_gpu, config.decode_engine_num_gpu) == (2, 0)
        connector.get_gpu_counts.assert_called_once_with(
            require_prefill=True, require_decode=False
        )

    def test_virtual_mode_uses_cli_args(self):
        """Without ``get_gpu_counts`` the counts already in the config are kept."""
        config = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
        connector = Mock(spec=[])

        _initialize_gpu_counts(
            config, connector, require_prefill=True, require_decode=True
        )

        assert (config.prefill_engine_num_gpu, config.decode_engine_num_gpu) == (2, 4)

    def test_virtual_mode_missing_prefill_raises_error(self):
        """A missing prefill GPU count is reported by name in the error."""
        config = self._make_config(decode_engine_num_gpu=4)
        connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as exc_info:
            _initialize_gpu_counts(
                config, connector, require_prefill=True, require_decode=True
            )

        # Checked after the ``with`` block so the assertion actually runs.
        assert "prefill_engine_num_gpu" in str(exc_info.value)

    def test_virtual_mode_missing_decode_raises_error(self):
        """A missing decode GPU count is reported by name in the error."""
        config = self._make_config(prefill_engine_num_gpu=2)
        connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as exc_info:
            _initialize_gpu_counts(
                config, connector, require_prefill=True, require_decode=True
            )

        assert "decode_engine_num_gpu" in str(exc_info.value)

    def test_virtual_mode_missing_both_raises_error_with_both_messages(self):
        """When both counts are missing the error carries two entries."""
        config = self._make_config()
        connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as exc_info:
            _initialize_gpu_counts(
                config, connector, require_prefill=True, require_decode=True
            )

        assert len(exc_info.value.errors) == 2

    def test_virtual_mode_decode_only_no_prefill_error(self):
        """Decode-only initialization does not require a prefill GPU count."""
        config = self._make_config(decode_engine_num_gpu=4)
        connector = Mock(spec=[])

        _initialize_gpu_counts(
            config, connector, require_prefill=False, require_decode=True
        )

        assert config.decode_engine_num_gpu == 4

    def test_kubernetes_mode_fallback_to_cli_on_dgd_error(self):
        """When ``get_gpu_counts`` raises ValueError the config values are kept."""
        config = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
        connector = self._dgd_connector(
            side_effect=ValueError("No GPU count specified")
        )

        _initialize_gpu_counts(
            config, connector, require_prefill=True, require_decode=True
        )

        assert (config.prefill_engine_num_gpu, config.decode_engine_num_gpu) == (2, 4)

    def test_kubernetes_mode_fallback_missing_cli_flags_raises_error(self):
        """Fallback with no config values raises an error carrying two entries."""
        config = self._make_config()
        connector = self._dgd_connector(
            side_effect=ValueError("No GPU count specified")
        )

        with pytest.raises(DeploymentValidationError) as exc_info:
            _initialize_gpu_counts(
                config, connector, require_prefill=True, require_decode=True
            )

        assert len(exc_info.value.errors) == 2

    def test_kubernetes_mode_fallback_partial_cli_flags(self):
        """Fallback with only the prefill count set reports the missing decode count."""
        config = self._make_config(prefill_engine_num_gpu=2)
        connector = self._dgd_connector(
            side_effect=ValueError("No GPU count specified")
        )

        with pytest.raises(DeploymentValidationError) as exc_info:
            _initialize_gpu_counts(
                config, connector, require_prefill=True, require_decode=True
            )

        assert "decode_engine_num_gpu" in str(exc_info.value)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment