Unverified Commit 66f7832a authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat(planner): unify throughput and load scaling on FPM regression (#7961)

parent 0b7a18ce
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
This module converts planner profiler's results for mocker to use. Convert planner profiler results to mocker-compatible NPZ format.
Uses the FPM-based regression models from ``dynamo.planner.core.perf_model``
to evaluate prefill TTFT and decode ITL on a regular grid, producing the
lookup tables that the mocker uses for latency simulation.
Example prefill query: Example prefill query:
input: input:
...@@ -10,8 +14,6 @@ Example prefill query: ...@@ -10,8 +14,6 @@ Example prefill query:
1. binary search prefill_isl to find isl_idx 1. binary search prefill_isl to find isl_idx
2. predicted TTFT is prefill_ttft_ms[isl_idx] 2. predicted TTFT is prefill_ttft_ms[isl_idx]
For chunked prefill, can ignore the KV cache read time and use ISL=prefill_tokens in this iteration.
This ignores the KV read time, which might leads to slightly lower latency..
Example decode query: Example decode query:
input: input:
...@@ -22,20 +24,18 @@ Example decode query: ...@@ -22,20 +24,18 @@ Example decode query:
2. binary search decode_active_kv_tokens to find kv_idx 2. binary search decode_active_kv_tokens to find kv_idx
3. binary search decode_context_length to find context_idx 3. binary search decode_context_length to find context_idx
4. predicted ITL is decode_itl[kv_idx, context_idx] 4. predicted ITL is decode_itl[kv_idx, context_idx]
For aggregated engines, can separately query prefill and decode and use their sum as the total latency.
This ignores the fact that active tokens' up/down projection is usually combine in one kernel,
and might leads to slightly higher latency.
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
from dynamo.planner.core.throughput.interpolation import ( from dynamo.planner.core.perf_model import DecodeRegressionModel, PrefillRegressionModel
DecodeInterpolator, from dynamo.planner.monitoring.perf_metrics import (
PrefillInterpolator, _convert_decode_profiling,
_convert_prefill_profiling,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -46,14 +46,17 @@ def convert_profile_results_to_npz( ...@@ -46,14 +46,17 @@ def convert_profile_results_to_npz(
output_path: str | Path, output_path: str | Path,
resolution: int = 100, resolution: int = 100,
) -> Path: ) -> Path:
""" """Convert planner profiler results to mocker-compatible NPZ format.
Convert planner profiler results directory to mocker-compatible NPZ format.
Loads the profiler's raw data (npz or JSON), fits FPM regression
models, and evaluates them on a regular grid to produce the lookup
tables the mocker expects.
Args: Args:
profile_results_dir: Path to directory containing selected_prefill_interpolation profile_results_dir: Path containing selected_prefill_interpolation
and selected_decode_interpolation subdirectories with raw_data.npz files. and selected_decode_interpolation subdirectories with raw_data.
output_path: Full path where the output perf_data.npz will be written. output_path: Full path where the output perf_data.npz will be written.
resolution: Resolution for the interpolation grid (default: 100). resolution: Resolution for the evaluation grid (default: 100).
Returns: Returns:
Path to the generated NPZ file. Path to the generated NPZ file.
...@@ -63,27 +66,68 @@ def convert_profile_results_to_npz( ...@@ -63,27 +66,68 @@ def convert_profile_results_to_npz(
logger.info(f"Converting profile results from {profile_results_dir}...") logger.info(f"Converting profile results from {profile_results_dir}...")
# Convert prefill data result: dict[str, Any] = {}
prefill_interpolator = PrefillInterpolator(profile_results_dir)
prefill_x = np.linspace( # --- Prefill: fit 1D model, evaluate TTFT on ISL grid ---
prefill_interpolator.ttft_interpolator.x.min(), prefill_fpms = _convert_prefill_profiling(profile_results_dir)
prefill_interpolator.ttft_interpolator.x.max(), if not prefill_fpms:
resolution, raise FileNotFoundError(
f"No prefill profiling data found in {profile_results_dir}"
) )
prefill_y = prefill_interpolator.ttft_interpolator(prefill_x)
result = { prefill_model = PrefillRegressionModel(
"prefill_isl": prefill_x.tolist(), max_num_fpm_samples=len(prefill_fpms) + 10,
"prefill_ttft_ms": prefill_y.tolist(), min_observations=1,
} )
prefill_model.load_benchmark_fpms(prefill_fpms)
if not prefill_model._ensure_fitted():
raise RuntimeError("Failed to fit prefill regression from profiling data")
isl_values = [
float(fpm.scheduled_requests.sum_prefill_tokens) for fpm in prefill_fpms
]
prefill_x = np.linspace(min(isl_values), max(isl_values), resolution)
prefill_y = np.array(
[prefill_model._predict_wall_time(isl) * 1000.0 for isl in prefill_x]
)
result["prefill_isl"] = prefill_x.tolist()
result["prefill_ttft_ms"] = prefill_y.tolist()
# Convert decode data # --- Decode: fit 2D model, evaluate ITL on (kv_tokens, context_length) grid ---
decode_interpolator = DecodeInterpolator(profile_results_dir, resolution=resolution) decode_fpms = _convert_decode_profiling(profile_results_dir)
if not decode_fpms:
raise FileNotFoundError(
f"No decode profiling data found in {profile_results_dir}"
)
decode_model = DecodeRegressionModel(
max_num_fpm_samples=len(decode_fpms) + 10,
min_observations=1,
)
decode_model.load_benchmark_fpms(decode_fpms)
if not decode_model._ensure_fitted():
raise RuntimeError("Failed to fit decode regression from profiling data")
decode_active_kv_tokens = decode_interpolator.xi * decode_interpolator.max_kv_tokens max_kv = max(
decode_context_length = decode_interpolator.yi float(fpm.scheduled_requests.sum_decode_kv_tokens) for fpm in decode_fpms
decode_itl = decode_interpolator.itl_interpolator.transpose() )
ctx_values = [
float(fpm.scheduled_requests.sum_decode_kv_tokens)
/ max(1, fpm.scheduled_requests.num_decode_requests)
for fpm in decode_fpms
if fpm.scheduled_requests.num_decode_requests > 0
]
max_ctx = max(ctx_values) if ctx_values else 8192.0
decode_active_kv_tokens = np.linspace(0, max_kv, resolution)
decode_context_length = np.linspace(1, max_ctx, resolution)
decode_itl = np.zeros((resolution, resolution))
for i, kv in enumerate(decode_active_kv_tokens):
for j, ctx in enumerate(decode_context_length):
bs = max(1, kv / ctx) if ctx > 0 else 1
decode_itl[i, j] = decode_model._predict_2d(bs, kv) * 1000.0
result["decode_active_kv_tokens"] = decode_active_kv_tokens.tolist() result["decode_active_kv_tokens"] = decode_active_kv_tokens.tolist()
result["decode_context_length"] = decode_context_length.tolist() result["decode_context_length"] = decode_context_length.tolist()
...@@ -96,18 +140,11 @@ def convert_profile_results_to_npz( ...@@ -96,18 +140,11 @@ def convert_profile_results_to_npz(
def is_profile_results_dir(path: Path) -> bool: def is_profile_results_dir(path: Path) -> bool:
""" """Check if the given path is a profile results directory.
Check if the given path is a profile results directory (profiler-style format).
A profile results directory contains: A profile results directory contains:
- selected_prefill_interpolation/raw_data.npz (or prefill_raw_data.json) - selected_prefill_interpolation/raw_data.npz (or prefill_raw_data.json)
- selected_decode_interpolation/raw_data.npz (or decode_raw_data.json) - selected_decode_interpolation/raw_data.npz (or decode_raw_data.json)
Args:
path: Path to check.
Returns:
True if path is a profile results directory, False otherwise.
""" """
if not path.is_dir(): if not path.is_dir():
return False return False
...@@ -124,18 +161,11 @@ def is_profile_results_dir(path: Path) -> bool: ...@@ -124,18 +161,11 @@ def is_profile_results_dir(path: Path) -> bool:
def is_mocker_format_npz(path: Path) -> bool: def is_mocker_format_npz(path: Path) -> bool:
""" """Check if the given path is a mocker-format NPZ file.
Check if the given path is a mocker-format NPZ file.
A mocker-format NPZ file contains: A mocker-format NPZ file contains:
- prefill_isl, prefill_ttft_ms - prefill_isl, prefill_ttft_ms
- decode_active_kv_tokens, decode_context_length, decode_itl - decode_active_kv_tokens, decode_context_length, decode_itl
Args:
path: Path to check.
Returns:
True if path is a valid mocker-format NPZ file, False otherwise.
""" """
if not path.is_file(): if not path.is_file():
return False return False
......
...@@ -58,7 +58,6 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -58,7 +58,6 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_r = 10.0 kalman_r = 10.0
kalman_min_points = 5 kalman_min_points = 5
no_correction = True
mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg" mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg"
throughput_metrics_source: Literal["frontend", "router"] = "frontend" throughput_metrics_source: Literal["frontend", "router"] = "frontend"
...@@ -68,8 +67,11 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -68,8 +67,11 @@ class SLAPlannerDefaults(BasePlannerDefaults):
enable_load_scaling = False enable_load_scaling = False
# Load-based scaling settings # Load-based scaling settings
load_adjustment_interval = 5 # in seconds, must be < throughput_adjustment_interval load_adjustment_interval = 5 # in seconds; also controls FPM regression update frequency for throughput scaling
load_learning_window = 50 # sliding window size for regression max_num_fpm_samples = 64 # max retained FPM observations for regression
fpm_sample_bucket_size = (
16 # must be a perfect square; total buckets across input axes
)
load_scaling_down_sensitivity = 80 # 0-100 load_scaling_down_sensitivity = 80 # 0-100
load_metric_samples = 10 # number of samples per interval load_metric_samples = 10 # number of samples per interval
load_min_observations = 5 # cold start threshold load_min_observations = 5 # cold start threshold
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import json import json
import logging import logging
import math
import os import os
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
...@@ -97,7 +98,6 @@ class PlannerConfig(BaseModel): ...@@ -97,7 +98,6 @@ class PlannerConfig(BaseModel):
"frontend", "router" "frontend", "router"
] = SLAPlannerDefaults.throughput_metrics_source ] = SLAPlannerDefaults.throughput_metrics_source
no_correction: bool = SLAPlannerDefaults.no_correction
model_name: Optional[str] = None model_name: Optional[str] = None
# Global planner environment # Global planner environment
...@@ -108,8 +108,18 @@ class PlannerConfig(BaseModel): ...@@ -108,8 +108,18 @@ class PlannerConfig(BaseModel):
enable_load_scaling: bool = SLAPlannerDefaults.enable_load_scaling enable_load_scaling: bool = SLAPlannerDefaults.enable_load_scaling
# Load-based scaling settings # Load-based scaling settings
load_adjustment_interval: int = SLAPlannerDefaults.load_adjustment_interval load_adjustment_interval: int = Field(
load_learning_window: int = SLAPlannerDefaults.load_learning_window default=SLAPlannerDefaults.load_adjustment_interval,
description=(
"Interval in seconds for FPM regression model updates AND load-based "
"scaling decisions. Even when only throughput-based scaling is enabled, "
"live FPM observations are fed into the regression at this interval to "
"keep the performance model accurate. Must be shorter than "
"throughput_adjustment_interval."
),
)
max_num_fpm_samples: int = SLAPlannerDefaults.max_num_fpm_samples
fpm_sample_bucket_size: int = SLAPlannerDefaults.fpm_sample_bucket_size
load_scaling_down_sensitivity: int = ( load_scaling_down_sensitivity: int = (
SLAPlannerDefaults.load_scaling_down_sensitivity SLAPlannerDefaults.load_scaling_down_sensitivity
) )
...@@ -118,7 +128,13 @@ class PlannerConfig(BaseModel): ...@@ -118,7 +128,13 @@ class PlannerConfig(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_config(self) -> "PlannerConfig": def _validate_config(self) -> "PlannerConfig":
# global-planner environment requires a namespace sqrt = math.isqrt(self.fpm_sample_bucket_size)
if sqrt * sqrt != self.fpm_sample_bucket_size:
raise ValueError(
f"fpm_sample_bucket_size must be a perfect square, "
f"got {self.fpm_sample_bucket_size}"
)
if self.environment == "global-planner" and not self.global_planner_namespace: if self.environment == "global-planner" and not self.global_planner_namespace:
raise ValueError( raise ValueError(
"global_planner_namespace is required when environment='global-planner'. " "global_planner_namespace is required when environment='global-planner'. "
...@@ -145,7 +161,6 @@ class PlannerConfig(BaseModel): ...@@ -145,7 +161,6 @@ class PlannerConfig(BaseModel):
) )
if self.enable_load_scaling: if self.enable_load_scaling:
# Load-based interval must be shorter than throughput interval
if self.enable_throughput_scaling: if self.enable_throughput_scaling:
if self.load_adjustment_interval >= self.throughput_adjustment_interval: if self.load_adjustment_interval >= self.throughput_adjustment_interval:
raise ValueError( raise ValueError(
...@@ -155,15 +170,6 @@ class PlannerConfig(BaseModel): ...@@ -155,15 +170,6 @@ class PlannerConfig(BaseModel):
"slow predictive loop." "slow predictive loop."
) )
# Auto-disable correction factor when load-based scaling is enabled
if not self.no_correction:
logger.warning(
"Correction factor is automatically disabled when load-based "
"scaling is enabled. Load-based scaling already accounts for "
"actual latency conditions."
)
self.no_correction = True
return self return self
@classmethod @classmethod
......
...@@ -3,23 +3,26 @@ ...@@ -3,23 +3,26 @@
import asyncio import asyncio
import logging import logging
import math
import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import SubComponentType, TargetReplica from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig from dynamo.planner.config.planner_config import PlannerConfig
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.base import BasePlanner from dynamo.planner.core.base import BasePlanner
from dynamo.planner.core.budget import ( from dynamo.planner.core.budget import (
_apply_component_gpu_budget, _apply_component_gpu_budget,
_initialize_gpu_counts, _initialize_gpu_counts,
) )
from dynamo.planner.core.perf_model import AggRegressionModel
from dynamo.planner.core.state import PlannerSharedState from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -27,19 +30,25 @@ logger = logging.getLogger(__name__) ...@@ -27,19 +30,25 @@ logger = logging.getLogger(__name__)
class AggPlanner: class AggPlanner:
"""Aggregated planner: FPM-driven load-based scaling, single engine type. """Aggregated planner: FPM-driven scaling for single engine type.
In aggregated mode, engines handle both prefill and decode (chunked prefill). In aggregated mode, engines handle both prefill and decode (chunked prefill).
A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens) A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens)
to wall_time using 2D linear regression. to wall_time using 2D linear regression.
Scaling logic: Supports load-only, throughput-only, or both scaling modes.
Scaling logic (load-based):
- Estimate next TTFT per engine by simulating prefill chunking with - Estimate next TTFT per engine by simulating prefill chunking with
piggybacked decode (steady-state decode load). piggybacked decode (steady-state decode load).
- Estimate next ITL per engine by predicting decode iteration time with - Estimate next ITL per engine by predicting decode iteration time with
average piggybacked prefill load. average piggybacked prefill load.
- Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA). - Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA).
- Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity). - Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity).
Scaling logic (throughput-based):
- Use compute_agg_replicas() to find minimum replicas where both SLAs
are met under predicted traffic load.
""" """
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None: def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
...@@ -47,14 +56,12 @@ class AggPlanner: ...@@ -47,14 +56,12 @@ class AggPlanner:
self.runtime = runtime self.runtime = runtime
self.shared_state = PlannerSharedState() self.shared_state = PlannerSharedState()
if config.enable_throughput_scaling: self.enable_throughput = config.enable_throughput_scaling
raise ValueError( self.enable_load = config.enable_load_scaling
"Aggregated planner only supports load-based scaling. "
"Set enable_throughput_scaling to false in the config." if not self.enable_throughput and not self.enable_load:
)
if not config.enable_load_scaling:
raise ValueError( raise ValueError(
"Aggregated planner requires enable_load_scaling to be true." "Aggregated planner requires at least one scaling mode enabled."
) )
prometheus_metrics = PlannerPrometheusMetrics() prometheus_metrics = PlannerPrometheusMetrics()
...@@ -68,11 +75,10 @@ class AggPlanner: ...@@ -68,11 +75,10 @@ class AggPlanner:
component_type=SubComponentType.DECODE, component_type=SubComponentType.DECODE,
) )
from dynamo.planner.core.load.fpm_regression import AggRegressionModel
self.regression = AggRegressionModel( self.regression = AggRegressionModel(
window_size=config.load_learning_window, max_num_fpm_samples=config.max_num_fpm_samples,
min_observations=config.load_min_observations, min_observations=config.load_min_observations,
bucket_count=config.fpm_sample_bucket_size,
) )
async def _async_init(self): async def _async_init(self):
...@@ -107,20 +113,153 @@ class AggPlanner: ...@@ -107,20 +113,153 @@ class AggPlanner:
await self.planner._init_worker_info(require_prefill=False, require_decode=True) await self.planner._init_worker_info(require_prefill=False, require_decode=True)
# Delegate FPM tracking to the inner BasePlanner (component_type=DECODE).
if self.runtime is not None: if self.runtime is not None:
await self.planner._init_fpm_subscriber() await self.planner._init_fpm_subscriber()
await self._bootstrap_regression()
async def _bootstrap_regression(self) -> None:
"""Bootstrap agg regression from pre-deployment benchmark data."""
worker_info = self.planner.decode_worker_info
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.config.namespace,
worker_info=worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=SubComponentType.DECODE,
)
self.regression.load_benchmark_fpms(fpms)
logger.info(
f"Bootstrapped agg regression with {len(fpms)} pre-deployment FPMs"
)
except Exception as e:
if self.enable_throughput:
raise
logger.warning(
f"No pre-deployment data for agg regression: {e}. "
"Load-based scaling will learn from live FPM only."
)
async def run(self): async def run(self):
"""Main scaling loop. Call _async_init() before this.""" """Main scaling loop. Call _async_init() before this."""
await asyncio.gather(self._load_loop()) self.shared_state.last_adjustment_time = time.time()
loops = []
if self.enable_throughput:
loops.append(self._throughput_loop())
loops.append(self._load_and_fpm_update_loop())
async def _load_loop(self) -> None: await asyncio.gather(*loops)
"""FPM-driven load-based scaling loop for aggregated mode."""
async def _throughput_loop(self) -> None:
"""Throughput-based scaling loop for agg mode."""
while True:
current_time = time.time()
if (
current_time - self.shared_state.last_adjustment_time
>= self.config.throughput_adjustment_interval
):
self.shared_state.last_adjustment_time = time.time()
logger.info("New agg throughput adjustment interval started!")
await self.planner.observe_traffic_stats(
require_prefill=False, require_decode=True
)
metrics = self.shared_state.last_metrics
if not metrics.is_valid():
logger.info("Metrics invalid, skipping agg throughput adjustment")
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
next_num_req = self.planner.num_req_predictor.predict_next()
next_isl = self.planner.isl_predictor.predict_next()
next_osl = self.planner.osl_predictor.predict_next()
max_num_batched_tokens = getattr(
self.planner.decode_worker_info, "max_num_batched_tokens", None
)
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available, skipping agg throughput"
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
(
engine_rps,
actual_ttft,
actual_itl,
) = self.regression.find_best_engine_agg_rps(
isl=next_isl,
osl=next_osl,
max_num_batched_tokens=max_num_batched_tokens,
ttft_sla=self.config.ttft,
itl_sla=self.config.itl,
)
if engine_rps <= 0:
logger.warning(
"Agg perf model not ready, skipping throughput scaling"
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
continue
if actual_ttft > self.config.ttft or actual_itl > self.config.itl:
logger.warning(
f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms "
f"(target {self.config.ttft:.1f}ms), "
f"ITL={actual_itl:.1f}ms (target {self.config.itl:.1f}ms), "
"scaling with best achievable rate"
)
demand_rps = next_num_req / self.config.throughput_adjustment_interval
desired = math.ceil(demand_rps / engine_rps)
desired = max(desired, self.config.min_endpoint)
logger.info(
f"Agg: {demand_rps:.2f}(demand rps) / "
f"{engine_rps:.2f}(engine rps) = {desired}(replicas), "
f"est_ttft={actual_ttft:.1f}ms, est_itl={actual_itl:.1f}ms"
)
if self.enable_load:
self.shared_state.throughput_lower_bound_d = desired
logger.info(f"Agg throughput lower bound set to {desired}")
else:
assert self.config.decode_engine_num_gpu is not None
desired = _apply_component_gpu_budget(
desired, self.config.decode_engine_num_gpu, self.config
)
if (
self.planner.prometheus_port != 0
and self.planner.prometheus_metrics is not None
):
self.planner.prometheus_metrics.predicted_num_d.set(desired)
if not self.config.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.planner.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
await self.planner.connector.set_component_replicas(
target_replicas, blocking=False
)
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_and_fpm_update_loop(self) -> None:
"""FPM observation and (optionally) load-based scaling for agg mode.
Always updates regression with live FPM. When load-based scaling
is enabled, makes scaling decisions immediately after.
"""
pending_desired: Optional[int] = None pending_desired: Optional[int] = None
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New agg load-based adjustment interval started!") logger.info("New agg load/FPM update interval started!")
_, num_d, _ = await self.planner.get_workers_info( _, num_d, _ = await self.planner.get_workers_info(
require_prefill=False, require_decode=True require_prefill=False, require_decode=True
...@@ -128,17 +267,17 @@ class AggPlanner: ...@@ -128,17 +267,17 @@ class AggPlanner:
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
num_workers = num_d num_workers = num_d
# Always observe FPM stats and update regression, even during scaling.
fpm_stats = self.planner._get_fpm_stats() fpm_stats = self.planner._get_fpm_stats()
if not fpm_stats: if not fpm_stats:
logger.warning("No FPM data available for agg engines")
continue continue
for (wid, dp), fpm in fpm_stats.items(): for (wid, dp), fpm in fpm_stats.items():
BasePlanner._log_fpm(wid, dp, fpm, "agg") BasePlanner._log_fpm(wid, dp, fpm, "agg")
self.regression.add_observation(fpm) self.regression.add_observation(fpm)
# If a previous scaling action is still in progress, skip decisions. if not self.enable_load:
continue
if pending_desired is not None: if pending_desired is not None:
if num_workers == pending_desired: if num_workers == pending_desired:
logger.info( logger.info(
...@@ -184,8 +323,6 @@ class AggPlanner: ...@@ -184,8 +323,6 @@ class AggPlanner:
f"(current={num_workers})" f"(current={num_workers})"
) )
# Scale up if EITHER dimension wants more workers.
# Scale down only if BOTH dimensions agree on fewer.
if p_desired is not None and p_desired > num_workers: if p_desired is not None and p_desired > num_workers:
desired = p_desired desired = p_desired
elif d_desired is not None and d_desired > num_workers: elif d_desired is not None and d_desired > num_workers:
...@@ -202,6 +339,8 @@ class AggPlanner: ...@@ -202,6 +339,8 @@ class AggPlanner:
continue continue
desired = max(desired, self.config.min_endpoint) desired = max(desired, self.config.min_endpoint)
if self.enable_throughput:
desired = max(desired, self.shared_state.throughput_lower_bound_d)
assert self.config.decode_engine_num_gpu is not None assert self.config.decode_engine_num_gpu is not None
desired = _apply_component_gpu_budget( desired = _apply_component_gpu_budget(
desired, self.config.decode_engine_num_gpu, self.config desired, self.config.decode_engine_num_gpu, self.config
...@@ -234,7 +373,6 @@ class AggPlanner: ...@@ -234,7 +373,6 @@ class AggPlanner:
num_workers: int, num_workers: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
) -> Optional[int]: ) -> Optional[int]:
"""Returns desired replica count for the prefill (TTFT) dimension, or None."""
estimated_ttfts: list[float] = [] estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items(): for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_ttft( est = self.regression.estimate_next_ttft(
...@@ -254,7 +392,6 @@ class AggPlanner: ...@@ -254,7 +392,6 @@ class AggPlanner:
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]", fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int, num_workers: int,
) -> Optional[int]: ) -> Optional[int]:
"""Returns desired replica count for the decode (ITL) dimension, or None."""
estimated_itls: list[float] = [] estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items(): for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_itl( est = self.regression.estimate_next_itl(
......
...@@ -19,12 +19,9 @@ from dynamo.planner.core.budget import ( ...@@ -19,12 +19,9 @@ from dynamo.planner.core.budget import (
_initialize_gpu_counts, _initialize_gpu_counts,
) )
from dynamo.planner.core.load.predictors import LOAD_PREDICTORS from dynamo.planner.core.load.predictors import LOAD_PREDICTORS
from dynamo.planner.core.perf_model import DecodeRegressionModel, PrefillRegressionModel
from dynamo.planner.core.state import PlannerSharedState from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.core.throughput.interpolation import ( from dynamo.planner.monitoring.perf_metrics import fetch_pre_deployment_metrics
DecodeInterpolator,
PrefillInterpolator,
)
from dynamo.planner.core.throughput.pre_swept_results import PreSweptResultsHelper
from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics from dynamo.planner.monitoring.planner_metrics import PlannerPrometheusMetrics
from dynamo.planner.monitoring.traffic_metrics import Metrics, PrometheusAPIClient from dynamo.planner.monitoring.traffic_metrics import Metrics, PrometheusAPIClient
from dynamo.planner.monitoring.worker_info import WorkerInfo, resolve_worker_info from dynamo.planner.monitoring.worker_info import WorkerInfo, resolve_worker_info
...@@ -33,6 +30,8 @@ from dynamo.planner.offline.trace_data import extract_metrics_from_mooncake ...@@ -33,6 +30,8 @@ from dynamo.planner.offline.trace_data import extract_metrics_from_mooncake
if TYPE_CHECKING: if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.llm import FpmEventSubscriber from dynamo.llm import FpmEventSubscriber
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -49,7 +48,6 @@ class BasePlanner: ...@@ -49,7 +48,6 @@ class BasePlanner:
self, self,
runtime: Optional[DistributedRuntime], runtime: Optional[DistributedRuntime],
config: PlannerConfig, config: PlannerConfig,
dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None, shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None, prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_traffic_client: Optional[PrometheusAPIClient] = None, prometheus_traffic_client: Optional[PrometheusAPIClient] = None,
...@@ -61,19 +59,16 @@ class BasePlanner: ...@@ -61,19 +59,16 @@ class BasePlanner:
self.component_type = component_type self.component_type = component_type
self.config = config self.config = config
self.dryrun = dryrun
self.shared_state = shared_state or PlannerSharedState() self.shared_state = shared_state or PlannerSharedState()
# Rely on getting model name from connector
self.model_name: Optional[str] = None
if not self.dryrun:
self.runtime = runtime self.runtime = runtime
self.namespace = config.namespace self.namespace = config.namespace
self.model_name: Optional[str] = None
self.connector: ConnectorType self.connector: ConnectorType
if not config.no_operation: if connector is not None:
# Initialize connector based on environment self.connector = connector
elif not config.no_operation:
if config.environment == "global-planner": if config.environment == "global-planner":
assert config.global_planner_namespace is not None assert config.global_planner_namespace is not None
assert runtime is not None assert runtime is not None
...@@ -85,9 +80,7 @@ class BasePlanner: ...@@ -85,9 +80,7 @@ class BasePlanner:
config.model_name, config.model_name,
) )
elif config.environment == "kubernetes": elif config.environment == "kubernetes":
self.connector = KubernetesConnector( self.connector = KubernetesConnector(self.namespace, config.model_name)
self.namespace, self.model_name
)
elif config.environment == "virtual": elif config.environment == "virtual":
assert runtime is not None assert runtime is not None
self.connector = VirtualConnector( self.connector = VirtualConnector(
...@@ -144,67 +137,23 @@ class BasePlanner: ...@@ -144,67 +137,23 @@ class BasePlanner:
if hasattr(p, "reset_idle_skip"): if hasattr(p, "reset_idle_skip"):
p.reset_idle_skip() p.reset_idle_skip()
# Load-based scaling flags.
# Argument validation (flag resolution, constraint checks, correction factor
# auto-disable) is handled by validate_sla_planner_args() in planner_argparse.
self.enable_load = config.enable_load_scaling self.enable_load = config.enable_load_scaling
self.enable_throughput = config.enable_throughput_scaling self.enable_throughput = config.enable_throughput_scaling
# Only create interpolators when throughput-based scaling is enabled
# (they require profiling data that isn't needed for load-based-only mode)
if self.enable_throughput:
if "use-pre-swept-results" in config.profile_results_dir:
config_list = config.profile_results_dir.split(":")
configs = {
"gpu_type": config_list[1],
"model": config_list[2],
"framework": config_list[3],
"framework_version": config_list[4],
"tp": int(config_list[5]),
"dp": int(config_list[6]),
"pp": int(config_list[7]),
"block_size": int(config_list[8]),
"max_batch_size": int(config_list[9]),
"gpu_count": int(config_list[10]),
}
if self.dryrun:
pre_swept_results_helper = PreSweptResultsHelper(
configs["gpu_type"], configs["framework"], configs["model"]
)
raw_data = pre_swept_results_helper.select_data("prefill", configs)
self.prefill_interpolator = PrefillInterpolator(raw_data=raw_data)
raw_data = pre_swept_results_helper.select_data("decode", configs)
self.decode_interpolator = DecodeInterpolator(raw_data=raw_data)
else:
raise ValueError(
"Cannot set profile_results_dir to 'use-pre-swept-results' in non-dryrun mode"
)
else:
self.prefill_interpolator = PrefillInterpolator(
config.profile_results_dir
)
self.decode_interpolator = DecodeInterpolator(
config.profile_results_dir
)
# WorkerInfo: finalized by _init_worker_info() at the start of run().
# Empty placeholders until then.
self.prefill_worker_info = WorkerInfo() self.prefill_worker_info = WorkerInfo()
self.decode_worker_info = WorkerInfo() self.decode_worker_info = WorkerInfo()
self.prometheus_metrics: PlannerPrometheusMetrics | None = None
if not self.dryrun:
self.prefill_client = None self.prefill_client = None
self.workers_client = None self.workers_client = None
self.prometheus_port = config.metric_reporting_prometheus_port self.prometheus_port = config.metric_reporting_prometheus_port
self.prometheus_metrics: PlannerPrometheusMetrics | None = None
if prometheus_metrics is None: if prometheus_metrics is None:
self.prometheus_metrics = PlannerPrometheusMetrics() self.prometheus_metrics = PlannerPrometheusMetrics()
else: else:
self.prometheus_metrics = prometheus_metrics self.prometheus_metrics = prometheus_metrics
# Start Prometheus HTTP server if port is specified
if start_prometheus_server and self.prometheus_port != 0: if start_prometheus_server and self.prometheus_port != 0:
try: try:
start_http_server(self.prometheus_port) start_http_server(self.prometheus_port)
...@@ -213,34 +162,20 @@ class BasePlanner: ...@@ -213,34 +162,20 @@ class BasePlanner:
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}") logger.error(f"Failed to start Prometheus metrics server: {e}")
else:
self.prometheus_port = 0
self.prometheus_metrics = prometheus_metrics
self.p_correction_factor = 1.0
self.d_correction_factor = 1.0
if self.dryrun:
self.no_correction = True
else:
self.no_correction = config.no_correction
if self.enable_load:
from dynamo.planner.core.load.fpm_regression import (
DecodeRegressionModel,
PrefillRegressionModel,
)
self.fpm_subscriber: "Optional[FpmEventSubscriber]" = None self.fpm_subscriber: "Optional[FpmEventSubscriber]" = None
if self.component_type == SubComponentType.PREFILL: if self.component_type == SubComponentType.PREFILL:
self.ttft_regression = PrefillRegressionModel( self.ttft_regression = PrefillRegressionModel(
window_size=self.config.load_learning_window, max_num_fpm_samples=self.config.max_num_fpm_samples,
min_observations=self.config.load_min_observations, min_observations=self.config.load_min_observations,
bucket_count=self.config.fpm_sample_bucket_size,
) )
elif self.component_type == SubComponentType.DECODE: elif self.component_type == SubComponentType.DECODE:
self.itl_regression = DecodeRegressionModel( self.itl_regression = DecodeRegressionModel(
window_size=self.config.load_learning_window, max_num_fpm_samples=self.config.max_num_fpm_samples,
min_observations=self.config.load_min_observations, min_observations=self.config.load_min_observations,
bucket_count=self.config.fpm_sample_bucket_size,
) )
@property @property
...@@ -271,17 +206,13 @@ class BasePlanner: ...@@ -271,17 +206,13 @@ class BasePlanner:
async def _async_init(self): async def _async_init(self):
"""Async initialization: connector init, deployment validation, WorkerInfo.""" """Async initialization: connector init, deployment validation, WorkerInfo."""
if ( if hasattr(self, "connector") and hasattr(self.connector, "_async_init"):
not self.dryrun
and hasattr(self, "connector")
and hasattr(self.connector, "_async_init")
):
await self.connector._async_init() await self.connector._async_init()
require_prefill = self.component_type == SubComponentType.PREFILL require_prefill = self.component_type == SubComponentType.PREFILL
require_decode = self.component_type == SubComponentType.DECODE require_decode = self.component_type == SubComponentType.DECODE
if not self.dryrun and not self.config.no_operation: if not self.config.no_operation:
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend) defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
logger.info("Validating deployment...") logger.info("Validating deployment...")
...@@ -315,11 +246,42 @@ class BasePlanner: ...@@ -315,11 +246,42 @@ class BasePlanner:
require_decode=require_decode, require_decode=require_decode,
) )
# Start FPM tracking if load-based scaling is enabled. if self.runtime is not None:
# The subscriber auto-discovers FPM publishers for this component.
if self.enable_load and self.runtime is not None:
await self._init_fpm_subscriber() await self._init_fpm_subscriber()
await self._bootstrap_regression()
async def _bootstrap_regression(self) -> None:
"""Fetch pre-deployment FPM data and bootstrap the regression model."""
worker_info = (
self.prefill_worker_info
if self.component_type == SubComponentType.PREFILL
else self.decode_worker_info
)
try:
fpms = await fetch_pre_deployment_metrics(
runtime=self.runtime,
namespace=self.namespace,
worker_info=worker_info,
profile_results_dir=self.config.profile_results_dir,
component_type=self.component_type,
)
if self.component_type == SubComponentType.PREFILL:
self.ttft_regression.load_benchmark_fpms(fpms)
elif self.component_type == SubComponentType.DECODE:
self.itl_regression.load_benchmark_fpms(fpms)
logger.info(
f"Bootstrapped {self.component_type.value} regression with "
f"{len(fpms)} pre-deployment FPMs"
)
except Exception as e:
if self.enable_throughput:
raise
logger.warning(
f"No pre-deployment data for {self.component_type.value} regression: {e}. "
"Load-based scaling will learn from live FPM only."
)
async def _init_fpm_subscriber(self) -> None: async def _init_fpm_subscriber(self) -> None:
"""Create and start the FPM subscriber for load-based scaling.""" """Create and start the FPM subscriber for load-based scaling."""
from dynamo.llm import FpmEventSubscriber from dynamo.llm import FpmEventSubscriber
...@@ -562,13 +524,6 @@ class BasePlanner: ...@@ -562,13 +524,6 @@ class BasePlanner:
logger.error(f"Failed to predict load: {e}") logger.error(f"Failed to predict load: {e}")
return None, None, None return None, None, None
def dryrun_observe_traffic_stats(
self, num_req: int, isl_avg: float, osl_avg: float
):
self.num_req_predictor.add_data_point(num_req)
self.isl_predictor.add_data_point(isl_avg)
self.osl_predictor.add_data_point(osl_avg)
def plan_adjustment(self) -> Optional[int]: def plan_adjustment(self) -> Optional[int]:
if not self.last_metrics.is_valid(): if not self.last_metrics.is_valid():
logger.info( logger.info(
...@@ -576,14 +531,6 @@ class BasePlanner: ...@@ -576,14 +531,6 @@ class BasePlanner:
) )
return None return None
if not self.no_correction:
try:
if not self._update_correction_factor():
return None
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return None
next_num_req, next_isl, next_osl = self.predict_load() next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is None or next_isl is None or next_osl is None: if next_num_req is None or next_isl is None or next_osl is None:
return None return None
...@@ -607,10 +554,7 @@ class BasePlanner: ...@@ -607,10 +554,7 @@ class BasePlanner:
def _compute_replica_requirements( def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float self, next_num_req: float, next_isl: float, next_osl: float
) -> int: ) -> Optional[int]:
raise NotImplementedError
def _update_correction_factor(self) -> bool:
raise NotImplementedError raise NotImplementedError
def _component_name(self) -> str: def _component_name(self) -> str:
...@@ -646,18 +590,7 @@ class BasePlanner: ...@@ -646,18 +590,7 @@ class BasePlanner:
] ]
await self.connector.set_component_replicas(target_replicas, blocking=False) await self.connector.set_component_replicas(target_replicas, blocking=False)
async def _apply_scaling_blocking(self, desired_replicas: int) -> None: _apply_scaling_blocking = _apply_scaling
"""Apply scaling without blocking so the loop continues observing metrics."""
if self.config.no_operation:
return
target_replicas = [
TargetReplica(
sub_component_type=self.component_type,
component_name=self._component_name(),
desired_replicas=desired_replicas,
)
]
await self.connector.set_component_replicas(target_replicas, blocking=False)
@staticmethod @staticmethod
def _reconcile_fpm_worker_count( def _reconcile_fpm_worker_count(
...@@ -854,30 +787,34 @@ class BasePlanner: ...@@ -854,30 +787,34 @@ class BasePlanner:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10) await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_loop(self, require_prefill: bool, require_decode: bool) -> None: async def _load_and_fpm_update_loop(
"""Load-based scaling loop at shorter interval. self, require_prefill: bool, require_decode: bool
) -> None:
"""FPM observation and (optionally) load-based scaling loop.
Uses FPM stats from the event plane (via FpmEventSubscriber) instead Runs every load_adjustment_interval. Always updates the FPM
of scraping the router's /metrics endpoint. regression model with live observations. When load-based scaling
is enabled, also makes scaling decisions immediately after the
FPM update.
""" """
pending_desired: Optional[int] = None pending_desired: Optional[int] = None
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load-based adjustment interval started!") logger.info("New load/FPM update interval started!")
# Query DGD for fresh worker counts
num_p, num_d, is_stable = await self.get_workers_info( num_p, num_d, is_stable = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode require_prefill=require_prefill, require_decode=require_decode
) )
self.shared_state.num_p_workers = num_p self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
# Always observe FPM stats and update regression, even during scaling.
fpm_stats = self.observe_fpm_load_stats() fpm_stats = self.observe_fpm_load_stats()
if not fpm_stats: if not fpm_stats:
continue continue
# If a previous scaling action is still in progress, skip decisions. if not self.enable_load:
continue
if pending_desired is not None: if pending_desired is not None:
dgd_count = ( dgd_count = (
num_p if self.component_type == SubComponentType.PREFILL else num_d num_p if self.component_type == SubComponentType.PREFILL else num_d
...@@ -905,7 +842,6 @@ class BasePlanner: ...@@ -905,7 +842,6 @@ class BasePlanner:
desired_replicas = self.load_plan_adjustment() desired_replicas = self.load_plan_adjustment()
if desired_replicas is not None: if desired_replicas is not None:
# Enforce lower bound from throughput-based
if self.enable_throughput: if self.enable_throughput:
if self.component_type == SubComponentType.PREFILL: if self.component_type == SubComponentType.PREFILL:
lower_bound = self.shared_state.throughput_lower_bound_p lower_bound = self.shared_state.throughput_lower_bound_p
...@@ -925,13 +861,9 @@ class BasePlanner: ...@@ -925,13 +861,9 @@ class BasePlanner:
self.shared_state.last_adjustment_time = time.time() self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time() self.shared_state.last_load_adjustment_time = time.time()
# Build list of concurrent loops based on enabled scaling modes.
# FPM tracking (started in _async_init) replaces the former
# DirectRouterMetricsClient.run_sampling_loop().
loops = [] loops = []
if self.enable_throughput: if self.enable_throughput:
loops.append(self._throughput_loop(require_prefill, require_decode)) loops.append(self._throughput_loop(require_prefill, require_decode))
if self.enable_load: loops.append(self._load_and_fpm_update_loop(require_prefill, require_decode))
loops.append(self._load_loop(require_prefill, require_decode))
await asyncio.gather(*loops) await asyncio.gather(*loops)
...@@ -66,67 +66,29 @@ class DecodePlanner(BasePlanner): ...@@ -66,67 +66,29 @@ class DecodePlanner(BasePlanner):
label="decode ITL", label="decode ITL",
) )
def _update_correction_factor(self) -> bool:
if self.shared_state.num_d_workers == 0:
logger.warning(
"No decode workers found for correction factor, skipping correction update"
)
return True
assert self.last_metrics.num_req is not None
assert self.last_metrics.request_duration is not None
assert self.last_metrics.isl is not None
assert self.last_metrics.osl is not None
assert self.last_metrics.itl is not None
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req
/ self.shared_state.num_d_workers
* self.last_metrics.request_duration
/ self.config.throughput_adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2,
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.d_correction_factor.set(self.d_correction_factor)
return True
def _compute_replica_requirements( def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float self, next_num_req: float, next_isl: float, next_osl: float
) -> int: ) -> Optional[int]:
if self.d_correction_factor <= 0: demand_rps = next_num_req / self.config.throughput_adjustment_interval
logger.warning( engine_rps, actual_itl_ms = self.itl_regression.find_best_engine_decode_rps(
f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0" itl=self.config.itl,
) context_length=next_isl + next_osl / 2,
corrected_itl = self.config.itl osl=next_osl,
else:
corrected_itl = self.config.itl / self.d_correction_factor
(
pred_decode_thpt_per_gpu,
_,
_,
) = self.decode_interpolator.find_best_throughput_per_gpu(
itl=corrected_itl, context_length=next_isl + next_osl / 2
) )
if pred_decode_thpt_per_gpu <= 0: if engine_rps <= 0:
logger.warning("Decode perf model not ready, skipping throughput scaling")
return None
if actual_itl_ms > self.config.itl:
logger.warning( logger.warning(
f"pred_decode_thpt_per_gpu is {pred_decode_thpt_per_gpu} " f"Decode ITL SLA not met: {actual_itl_ms:.1f}ms > "
"(no throughput satisfies ITL target), falling back to min_endpoint" f"{self.config.itl:.1f}ms, scaling with best achievable rate"
)
return self.config.min_endpoint
assert self.config.decode_engine_num_gpu is not None
pred_decode_throughput = (
next_num_req * next_osl / self.config.throughput_adjustment_interval
)
next_num_d = math.ceil(
pred_decode_throughput
/ pred_decode_thpt_per_gpu
/ self.config.decode_engine_num_gpu
) )
next_num_d = math.ceil(demand_rps / engine_rps)
next_num_d = max(next_num_d, self.config.min_endpoint) next_num_d = max(next_num_d, self.config.min_endpoint)
logger.info( logger.info(
f"Decode calculation: {pred_decode_throughput:.2f}(d_thpt) / " f"Decode: {demand_rps:.2f}(demand rps) / "
f"{pred_decode_thpt_per_gpu * self.config.decode_engine_num_gpu:.2f}(d_engine_cap) = " f"{engine_rps:.2f}(engine rps) = {next_num_d}(num_d), "
f"{next_num_d}(num_d)" f"est_itl={actual_itl_ms:.1f}ms"
) )
return next_num_d return next_num_d
......
...@@ -94,26 +94,23 @@ class DisaggPlanner: ...@@ -94,26 +94,23 @@ class DisaggPlanner:
self.decode_planner.decode_worker_info = self.prefill_planner.decode_worker_info self.decode_planner.decode_worker_info = self.prefill_planner.decode_worker_info
self.decode_planner.model_name = self.prefill_planner.model_name self.decode_planner.model_name = self.prefill_planner.model_name
# Start FPM tracking for both planners. DisaggPlanner bypasses each
# sub-planner's _async_init(), so we init subscribers explicitly here.
if self.enable_load:
if self.prefill_planner.runtime is not None: if self.prefill_planner.runtime is not None:
await self.prefill_planner._init_fpm_subscriber() await self.prefill_planner._init_fpm_subscriber()
if self.decode_planner.runtime is not None: if self.decode_planner.runtime is not None:
await self.decode_planner._init_fpm_subscriber() await self.decode_planner._init_fpm_subscriber()
await self.prefill_planner._bootstrap_regression()
await self.decode_planner._bootstrap_regression()
async def run(self): async def run(self):
"""Main scaling loop. Call _async_init() before this.""" """Main scaling loop. Call _async_init() before this."""
self.shared_state.last_adjustment_time = time.time() self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time() self.shared_state.last_load_adjustment_time = time.time()
# FPM tracking (started in _async_init) replaces the former
# DirectRouterMetricsClient.run_sampling_loop().
loops = [] loops = []
if self.enable_throughput: if self.enable_throughput:
loops.append(self._throughput_loop()) loops.append(self._throughput_loop())
if self.enable_load: loops.append(self._load_and_fpm_update_loop())
loops.append(self._load_loop())
await asyncio.gather(*loops) await asyncio.gather(*loops)
...@@ -175,11 +172,15 @@ class DisaggPlanner: ...@@ -175,11 +172,15 @@ class DisaggPlanner:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10) await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_loop(self) -> None: async def _load_and_fpm_update_loop(self) -> None:
"""FPM-driven load-based scaling loop for disagg mode.""" """FPM observation and (optionally) load-based scaling for disagg mode.
Always updates regression models with live FPM. When load-based
scaling is enabled, makes scaling decisions immediately after.
"""
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load-based adjustment interval started!") logger.info("New load/FPM update interval started!")
num_p, num_d, _ = await self.prefill_planner.get_workers_info( num_p, num_d, _ = await self.prefill_planner.get_workers_info(
require_prefill=True, require_decode=True require_prefill=True, require_decode=True
...@@ -187,10 +188,12 @@ class DisaggPlanner: ...@@ -187,10 +188,12 @@ class DisaggPlanner:
self.shared_state.num_p_workers = num_p self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
# Observe FPM stats and feed into regression models
p_stats = self.prefill_planner.observe_fpm_load_stats() p_stats = self.prefill_planner.observe_fpm_load_stats()
d_stats = self.decode_planner.observe_fpm_load_stats() d_stats = self.decode_planner.observe_fpm_load_stats()
if not self.enable_load:
continue
if not p_stats and not d_stats: if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping") logger.warning("No FPM data for either prefill or decode, skipping")
continue continue
...@@ -221,16 +224,13 @@ class DisaggPlanner: ...@@ -221,16 +224,13 @@ class DisaggPlanner:
logger.info("Load-based scaling: no scaling needed") logger.info("Load-based scaling: no scaling needed")
continue continue
# Enforce lower bounds from throughput-based
if self.enable_throughput: if self.enable_throughput:
final_p = max(final_p, self.shared_state.throughput_lower_bound_p) final_p = max(final_p, self.shared_state.throughput_lower_bound_p)
final_d = max(final_d, self.shared_state.throughput_lower_bound_d) final_d = max(final_d, self.shared_state.throughput_lower_bound_d)
# Enforce minimum endpoints
final_p = max(final_p, self.config.min_endpoint) final_p = max(final_p, self.config.min_endpoint)
final_d = max(final_d, self.config.min_endpoint) final_d = max(final_d, self.config.min_endpoint)
# Apply GPU budget
final_p, final_d = _apply_global_gpu_budget(final_p, final_d, self.config) final_p, final_d = _apply_global_gpu_budget(final_p, final_d, self.config)
logger.info( logger.info(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""FPM-driven regression models for load-based scaling.
Each model takes ForwardPassMetrics observations and estimates per-engine
TTFT or ITL by simulating the scheduler's chunked prefill / decode
iteration pipeline.
- PrefillRegressionModel: 1D regression (sum_prefill_tokens -> wall_time)
- DecodeRegressionModel: 1D regression (sum_decode_kv_tokens -> wall_time)
- AggRegressionModel: 2D regression (sum_prefill_tokens, sum_decode_kv_tokens -> wall_time)
"""
import logging
import math
from collections import deque
from typing import Optional, Union
import numpy as np
from sklearn.linear_model import LinearRegression
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class _MovingAverage:
"""Fixed-window moving average that skips leading zeros.
Initial zero values (pre-traffic idle period) are ignored until the
first non-zero value arrives, matching the throughput planner's
load predictor behavior.
"""
__slots__ = ("_window", "_sum", "_seen_nonzero")
def __init__(self, window_size: int):
self._window: deque[float] = deque(maxlen=window_size)
self._sum: float = 0.0
self._seen_nonzero: bool = False
def add(self, value: float) -> None:
if value == 0.0 and not self._seen_nonzero:
return
if value != 0.0:
self._seen_nonzero = True
if len(self._window) == self._window.maxlen:
self._sum -= self._window[0]
self._window.append(value)
self._sum += value
@property
def value(self) -> float:
if not self._window:
return 0.0
return self._sum / len(self._window)
def __len__(self) -> int:
return len(self._window)
class _BaseRegressionModel:
"""Shared regression infrastructure for FPM-based models."""
def __init__(self, window_size: int, min_observations: int = 5, ndim: int = 1):
self.window_size = window_size
self.min_observations = min_observations
self._ndim = ndim
self._observations: deque[tuple[Union[float, list[float]], float]] = deque(
maxlen=window_size
)
self._model = LinearRegression()
self._is_fitted = False
def _extract_x(self, fpm: ForwardPassMetrics) -> Union[float, list[float]]:
"""Return the regression input(s) from an FPM snapshot."""
raise NotImplementedError
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
"""Update moving averages (called for every FPM, including idle)."""
raise NotImplementedError
def add_observation(self, fpm: ForwardPassMetrics) -> None:
# Always update moving averages so idle state is reflected.
self._update_moving_averages(fpm)
if fpm.wall_time == 0.0:
return
self._observations.append((self._extract_x(fpm), fpm.wall_time))
self._is_fitted = False
def _fit(self) -> bool:
if len(self._observations) < self.min_observations:
return False
X = np.array([o[0] for o in self._observations])
if self._ndim == 1:
X = X.reshape(-1, 1)
y = np.array([o[1] for o in self._observations])
self._model.fit(X, y)
self._is_fitted = True
return True
def _ensure_fitted(self) -> bool:
return self._is_fitted or self._fit()
def has_sufficient_data(self) -> bool:
return len(self._observations) >= self.min_observations
@property
def num_observations(self) -> int:
return len(self._observations)
class PrefillRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from scheduled prefill tokens.
Regression: wall_time = f(sum_prefill_tokens)
Simulation: estimate TTFT by chunking queued_prefill_tokens + avg_isl
into max_num_batched_tokens-sized iterations and summing
the predicted wall time for each.
"""
def __init__(self, window_size: int, min_observations: int = 5):
super().__init__(window_size, min_observations, ndim=1)
self._avg_isl = _MovingAverage(window_size)
self._avg_num_prefill = _MovingAverage(window_size)
def _extract_x(self, fpm: ForwardPassMetrics) -> float:
return float(fpm.scheduled_requests.sum_prefill_tokens)
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_prefill_requests > 0:
self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
self._avg_num_prefill.add(float(sched.num_prefill_requests))
@property
def avg_isl(self) -> float:
return self._avg_isl.value
def estimate_next_ttft(
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
) -> Optional[float]:
"""Simulate prefill scheduling to estimate TTFT for the next request.
The scheduler processes prefill tokens in chunks of
max_num_batched_tokens per iteration. We sum the regression-predicted
wall time for each chunk to approximate TTFT.
Args:
queued_prefill_tokens: tokens already queued ahead of the next request.
max_num_batched_tokens: per-iteration token budget (from WorkerInfo/MDC).
Returns:
Estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
if total_tokens <= 0:
return 0.0
num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
total_time = 0.0
remaining = total_tokens
for _ in range(num_iterations):
chunk = min(remaining, max_num_batched_tokens)
pred = self._model.predict(np.array([[chunk]]))[0]
total_time += max(0.0, float(pred))
remaining -= chunk
return total_time
class DecodeRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from scheduled decode KV tokens.
Regression: wall_time = f(sum_decode_kv_tokens)
Estimation: predict ITL for the next decode step accounting for
queued (preempted) decode load and one additional request.
"""
def __init__(self, window_size: int, min_observations: int = 5):
super().__init__(window_size, min_observations, ndim=1)
self._avg_decode_len = _MovingAverage(window_size)
self._avg_num_decode = _MovingAverage(window_size)
def _extract_x(self, fpm: ForwardPassMetrics) -> float:
return float(fpm.scheduled_requests.sum_decode_kv_tokens)
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_decode_requests > 0:
self._avg_decode_len.add(
sched.sum_decode_kv_tokens / sched.num_decode_requests
)
self._avg_num_decode.add(float(sched.num_decode_requests))
@property
def avg_decode_length(self) -> float:
return self._avg_decode_len.value
def estimate_next_itl(
self,
scheduled_decode_kv: int,
queued_decode_kv: int,
) -> Optional[float]:
"""Estimate the next decode iteration time.
Predicts wall time for the total decode KV load: currently scheduled +
queued (preempted) + one additional request worth of decode context.
Args:
scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
queued_decode_kv: sum_decode_kv_tokens from the queued metrics.
Returns:
Estimated ITL in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted():
return None
total_kv = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
return max(0.0, float(self._model.predict(np.array([[total_kv]]))[0]))
class AggRegressionModel(_BaseRegressionModel):
"""2D regression for aggregated (chunked prefill + decode) engines.
Regression: wall_time = f(sum_prefill_tokens, sum_decode_kv_tokens)
Estimation: estimate TTFT by simulating prefill chunking while assuming
steady-state decode load; estimate ITL by predicting decode
iteration time while assuming average piggybacked prefill load.
"""
def __init__(self, window_size: int, min_observations: int = 5):
super().__init__(window_size, min_observations, ndim=2)
self._avg_isl = _MovingAverage(window_size)
self._avg_decode_len = _MovingAverage(window_size)
self._avg_prefill_tokens = _MovingAverage(window_size)
self._avg_num_prefill = _MovingAverage(window_size)
self._avg_num_decode = _MovingAverage(window_size)
def _extract_x(self, fpm: ForwardPassMetrics) -> list[float]:
sched = fpm.scheduled_requests
return [float(sched.sum_prefill_tokens), float(sched.sum_decode_kv_tokens)]
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_prefill_requests > 0:
self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
if sched.num_decode_requests > 0:
self._avg_decode_len.add(
sched.sum_decode_kv_tokens / sched.num_decode_requests
)
self._avg_prefill_tokens.add(float(sched.sum_prefill_tokens))
self._avg_num_prefill.add(float(sched.num_prefill_requests))
self._avg_num_decode.add(float(sched.num_decode_requests))
@property
def avg_isl(self) -> float:
return self._avg_isl.value
@property
def avg_decode_length(self) -> float:
return self._avg_decode_len.value
@property
def avg_prefill_tokens(self) -> float:
return self._avg_prefill_tokens.value
def _predict_2d(self, prefill_tokens: float, decode_kv_tokens: float) -> float:
return float(
self._model.predict(np.array([[prefill_tokens, decode_kv_tokens]]))[0]
)
def estimate_next_ttft(
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
current_decode_kv: int,
) -> Optional[float]:
"""Simulate prefill scheduling with piggybacked decode.
Same chunking simulation as PrefillRegressionModel, but each
iteration also carries the current decode KV load (steady state).
Args:
queued_prefill_tokens: prefill tokens queued ahead of the next request.
max_num_batched_tokens: per-iteration token budget (from MDC).
current_decode_kv: scheduled decode KV tokens from the latest FPM
(assumed steady during prefill).
Returns:
Estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
if total_tokens <= 0:
return 0.0
num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
total_time = 0.0
remaining = total_tokens
for _ in range(num_iterations):
chunk = min(remaining, max_num_batched_tokens)
total_time += max(0.0, self._predict_2d(chunk, float(current_decode_kv)))
remaining -= chunk
return total_time
def estimate_next_itl(
self,
scheduled_decode_kv: int,
queued_decode_kv: int,
) -> Optional[float]:
"""Estimate decode iteration time with piggybacked prefill.
Uses the moving average of scheduled prefill tokens as the
piggybacked prefill load in the next iteration.
Args:
scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
queued_decode_kv: sum_decode_kv_tokens from the queued metrics.
Returns:
Estimated ITL in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted():
return None
total_kv = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
return max(0.0, self._predict_2d(self._avg_prefill_tokens.value, total_kv))
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import deque
from typing import Optional
import numpy as np
from sklearn.linear_model import LinearRegression
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class LoadBasedRegressionModel:
"""Sliding window linear regression for load-based scaling.
Maintains a fixed-size window of (X, y) observations and provides:
- Forward prediction: y = mx + b (given X, predict latency)
- Reverse prediction: X = (y - b) / m (given target SLA, find max load)
Used to map:
- Prefill: (active_prefill_tokens + ISL) -> TTFT
- Decode: active_decode_blocks -> ITL
"""
def __init__(self, window_size: int, min_observations: int = 5):
self.window_size = window_size
self.min_observations = min_observations
self._observations: deque = deque(maxlen=window_size)
self._model = LinearRegression()
self._is_fitted = False
def add_observation(self, x: float, y: float) -> None:
"""Add an (X, y) observation to the sliding window."""
self._observations.append((x, y))
self._is_fitted = False
def fit(self) -> bool:
"""Fit the linear regression model on current observations.
Returns:
True if fitting succeeded, False if insufficient data.
"""
if len(self._observations) < self.min_observations:
return False
X = np.array([obs[0] for obs in self._observations]).reshape(-1, 1)
y = np.array([obs[1] for obs in self._observations])
self._model.fit(X, y)
self._is_fitted = True
return True
def predict_x_from_sla(self, target_y: float) -> Optional[float]:
"""Reverse prediction: given a target latency (SLA), find the max load.
Solves: x = (y - b) / m
Safety guards:
- Returns None if insufficient data (cold start)
- Falls back to observation-based heuristic if slope <= 0
- Clamps result to non-negative
Args:
target_y: Target latency SLA value (e.g., TTFT in ms, ITL in ms)
Returns:
Maximum load value that satisfies the SLA, or None if insufficient data.
"""
if not self._is_fitted and not self.fit():
return None
coef = float(self._model.coef_[0])
intercept = float(self._model.intercept_)
if coef <= 0:
logger.warning(
f"Regression slope is non-positive ({coef:.6f}), "
"falling back to observation-based heuristic"
)
return self._fallback_x_from_observations(target_y)
x_sla = (target_y - intercept) / coef
return max(0.0, x_sla)
def _fallback_x_from_observations(self, target_y: float) -> float:
"""Fallback when regression slope is non-positive.
Returns the minimum x among observations where y < target_y.
If all observations have y >= target_y, returns the smallest x overall.
"""
below = [(x, y) for x, y in self._observations if y < target_y]
if below:
result = min(x for x, _ in below)
else:
result = min(x for x, _ in self._observations)
logger.info(
f"Fallback x from observations: {result:.1f} "
f"(points below SLA: {len(below)}/{len(self._observations)})"
)
return max(0.0, result)
def has_sufficient_data(self) -> bool:
"""Check if enough observations have been collected (cold start guard)."""
return len(self._observations) >= self.min_observations
@property
def num_observations(self) -> int:
return len(self._observations)
@property
def slope(self) -> Optional[float]:
"""Return the current regression slope, or None if not fitted."""
if not self._is_fitted and not self.fit():
return None
return float(self._model.coef_[0])
@property
def intercept(self) -> Optional[float]:
"""Return the current regression intercept, or None if not fitted."""
if not self._is_fitted and not self.fit():
return None
return float(self._model.intercept_)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.planner.core.perf_model.agg import AggRegressionModel
from dynamo.planner.core.perf_model.decode import DecodeRegressionModel
from dynamo.planner.core.perf_model.prefill import PrefillRegressionModel
__all__ = [
"PrefillRegressionModel",
"DecodeRegressionModel",
"AggRegressionModel",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Aggregated (chunked prefill + decode) engine performance model.
Regression: wall_time = f(sum_prefill_tokens, sum_decode_kv_tokens)
"""
import logging
import math
from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage
logger = logging.getLogger(__name__)
class AggRegressionModel(_BaseRegressionModel):
"""2D regression for aggregated (chunked prefill + decode) engines."""
def __init__(
self,
max_num_fpm_samples: int,
min_observations: int = 5,
bucket_count: int = 16,
):
super().__init__(
max_num_fpm_samples, min_observations, ndim=2, bucket_count=bucket_count
)
self._avg_isl = _MovingAverage(max_num_fpm_samples)
self._avg_decode_len = _MovingAverage(max_num_fpm_samples)
self._avg_prefill_tokens = _MovingAverage(max_num_fpm_samples)
self._avg_num_prefill = _MovingAverage(max_num_fpm_samples)
self._avg_num_decode = _MovingAverage(max_num_fpm_samples)
def _extract_x(self, fpm: ForwardPassMetrics) -> list[float]:
sched = fpm.scheduled_requests
return [float(sched.sum_prefill_tokens), float(sched.sum_decode_kv_tokens)]
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_prefill_requests > 0:
self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
if sched.num_decode_requests > 0:
self._avg_decode_len.add(
sched.sum_decode_kv_tokens / sched.num_decode_requests
)
self._avg_prefill_tokens.add(float(sched.sum_prefill_tokens))
self._avg_num_prefill.add(float(sched.num_prefill_requests))
self._avg_num_decode.add(float(sched.num_decode_requests))
@property
def avg_isl(self) -> float:
return self._avg_isl.value
@property
def avg_decode_length(self) -> float:
return self._avg_decode_len.value
@property
def avg_prefill_tokens(self) -> float:
return self._avg_prefill_tokens.value
def _predict_2d(self, prefill_tokens: float, decode_kv_tokens: float) -> float:
return max(
1e-6,
float(
self._model.predict(np.array([[prefill_tokens, decode_kv_tokens]]))[0]
),
)
def estimate_next_ttft(
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
current_decode_kv: int,
) -> Optional[float]:
"""Simulate prefill scheduling with piggybacked decode.
Returns estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
if total_tokens <= 0:
return 0.0
num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
total_time = 0.0
remaining = total_tokens
for _ in range(num_iterations):
chunk = min(remaining, max_num_batched_tokens)
total_time += self._predict_2d(chunk, float(current_decode_kv))
remaining -= chunk
return total_time
def estimate_next_itl(
self,
scheduled_decode_kv: int,
queued_decode_kv: int,
) -> Optional[float]:
"""Estimate decode iteration time with piggybacked prefill.
Returns estimated ITL in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted():
return None
total_kv = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
return self._predict_2d(self._avg_prefill_tokens.value, total_kv)
def find_best_engine_agg_rps(
self,
isl: float,
osl: float,
max_num_batched_tokens: int,
ttft_sla: float,
itl_sla: float,
) -> tuple[float, float, float]:
"""Find the maximum agg engine request rate under both SLA targets.
Sweeps over batch_size to find the largest decode concurrency
where both ITL and TTFT remain within their targets. Warns if
even batch_size=1 violates either SLA.
Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
Args:
isl: average input sequence length (tokens).
osl: average output sequence length (tokens).
max_num_batched_tokens: per-iteration token budget.
ttft_sla: TTFT target in milliseconds.
itl_sla: ITL target in milliseconds.
Returns:
(engine_rps, actual_ttft_ms, actual_itl_ms) -- 0 rps
signals an error (model not fitted or invalid input);
positive rps is the best achievable rate with the
predicted TTFT/ITL. If SLAs are violated, a warning
is logged but the rate is still returned.
"""
if (
not self._ensure_fitted()
or isl <= 0
or osl <= 0
or max_num_batched_tokens <= 0
):
return (0.0, 0.0, 0.0)
avg_ctx = isl + osl / 2.0
max_bs = max(1, int(max_num_batched_tokens / max(1, avg_ctx))) * 2
best_rps = 0.0
best_ttft_ms = 0.0
best_itl_ms = 0.0
for bs in range(1, max_bs + 1):
decode_kv = bs * avg_ctx
prefill_per_iter = min(bs * isl / max(1.0, osl), max_num_batched_tokens)
wt = self._predict_2d(prefill_per_iter, decode_kv)
itl_ms = wt * 1000.0
est_ttft = self.estimate_next_ttft(
queued_prefill_tokens=int(prefill_per_iter),
max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=int(decode_kv),
)
ttft_ms = est_ttft * 1000.0 if est_ttft is not None else 0.0
if itl_ms > itl_sla or ttft_ms > ttft_sla:
if bs == 1:
logger.warning(
f"Agg SLA unreachable at batch_size=1: "
f"TTFT={ttft_ms:.1f}ms (target {ttft_sla:.1f}ms), "
f"ITL={itl_ms:.1f}ms (target {itl_sla:.1f}ms)"
)
best_rps = 1.0 / (osl * wt)
best_ttft_ms = ttft_ms
best_itl_ms = itl_ms
break
best_rps = bs / (osl * wt)
best_ttft_ms = ttft_ms
best_itl_ms = itl_ms
return (best_rps, best_ttft_ms, best_itl_ms)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base regression infrastructure and utilities for FPM-based performance models.
Provides ``_BaseRegressionModel`` (bucketed observation storage with
linear regression) and ``_MovingAverage``, shared by the prefill,
decode, and agg perf model subclasses.
"""
import logging
import math
from collections import defaultdict, deque
from typing import Union
import numpy as np
from sklearn.linear_model import LinearRegression
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class _MovingAverage:
"""Fixed-window moving average that skips leading zeros.
Initial zero values (pre-traffic idle period) are ignored until the
first non-zero value arrives, matching the throughput planner's
load predictor behavior.
"""
__slots__ = ("_window", "_sum", "_seen_nonzero")
def __init__(self, window_size: int):
self._window: deque[float] = deque(maxlen=window_size)
self._sum: float = 0.0
self._seen_nonzero: bool = False
def add(self, value: float) -> None:
if value == 0.0 and not self._seen_nonzero:
return
if value != 0.0:
self._seen_nonzero = True
if len(self._window) == self._window.maxlen:
self._sum -= self._window[0]
self._window.append(value)
self._sum += value
@property
def value(self) -> float:
if not self._window:
return 0.0
return self._sum / len(self._window)
def __len__(self) -> int:
return len(self._window)
# ---------------------------------------------------------------------------
# Bucketed FPM sample retirement.
#
# FPM observations span diverse engine load conditions (from
# pre-deployment benchmarks through live traffic). A simple FIFO
# window would let sustained traffic at one operating point push
# out data for other conditions, degrading the regression's coverage
# of the full performance surface.
#
# Instead, each input axis is divided into equal-width buckets.
# Observations are assigned to a bucket based on their input features.
# When total samples exceed max_num_fpm_samples, the oldest sample in
# the bucket with the most entries is retired. This keeps the sample
# distribution roughly uniform across the operating range.
#
# fpm_sample_bucket_size controls the total number of buckets:
# - 1D models: fpm_sample_bucket_size buckets along the single axis
# - 2D models: sqrt(fpm_sample_bucket_size) buckets per axis
# (e.g., 16 -> 4x4 grid)
# The config requires fpm_sample_bucket_size to be a perfect square
# so the 2D decomposition is always clean.
# ---------------------------------------------------------------------------
class _BaseRegressionModel:
"""Shared regression infrastructure for FPM-based models."""
def __init__(
self,
max_num_fpm_samples: int,
min_observations: int = 5,
ndim: int = 1,
bucket_count: int = 16,
):
self.max_num_fpm_samples = max_num_fpm_samples
self.min_observations = min_observations
self._ndim = ndim
if ndim == 1:
self._buckets_per_axis = bucket_count
else:
self._buckets_per_axis = math.isqrt(bucket_count)
self._buckets: dict[
tuple[int, ...], deque[tuple[Union[float, list[float]], float]]
] = defaultdict(deque)
self._total_observations = 0
self._axis_min: list[float] = [float("inf")] * ndim
self._axis_max: list[float] = [float("-inf")] * ndim
self._model = LinearRegression()
self._is_fitted = False
def _extract_x(self, fpm: ForwardPassMetrics) -> Union[float, list[float]]:
raise NotImplementedError
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
raise NotImplementedError
def _to_vals(self, x: Union[float, list[float]]) -> list[float]:
if isinstance(x, list):
return x
return [x]
def _bucket_key(self, x: Union[float, list[float]]) -> tuple[int, ...]:
"""Compute the bucket index for an observation's input features."""
vals = self._to_vals(x)
key = []
for i, v in enumerate(vals):
lo, hi = self._axis_min[i], self._axis_max[i]
if hi <= lo:
key.append(0)
else:
idx = int((v - lo) / (hi - lo) * self._buckets_per_axis)
key.append(max(0, min(idx, self._buckets_per_axis - 1)))
return tuple(key)
def _update_axis_bounds(self, x: Union[float, list[float]]) -> bool:
"""Update min/max per axis. Returns True if bounds changed."""
vals = self._to_vals(x)
changed = False
for i, v in enumerate(vals):
if v < self._axis_min[i]:
self._axis_min[i] = v
changed = True
if v > self._axis_max[i]:
self._axis_max[i] = v
changed = True
return changed
def _rebuild_buckets(self) -> None:
"""Re-index all observations into buckets using current axis bounds."""
all_obs = self._gather_observations()
self._buckets.clear()
for x, wt in all_obs:
key = self._bucket_key(x)
self._buckets[key].append((x, wt))
def add_observation(self, fpm: ForwardPassMetrics) -> None:
self._update_moving_averages(fpm)
if fpm.wall_time == 0.0:
return
x = self._extract_x(fpm)
bounds_changed = self._update_axis_bounds(x)
if bounds_changed and self._total_observations > 0:
self._rebuild_buckets()
key = self._bucket_key(x)
self._buckets[key].append((x, fpm.wall_time))
self._total_observations += 1
if self._total_observations > self.max_num_fpm_samples:
fattest_key = max(self._buckets, key=lambda k: len(self._buckets[k]))
self._buckets[fattest_key].popleft()
self._total_observations -= 1
if not self._buckets[fattest_key]:
del self._buckets[fattest_key]
self._is_fitted = False
def load_benchmark_fpms(self, fpms: list[ForwardPassMetrics]) -> None:
"""Bootstrap regression from pre-deployment benchmark FPMs."""
for fpm in fpms:
self.add_observation(fpm)
def _gather_observations(self) -> list[tuple[Union[float, list[float]], float]]:
return [obs for bucket in self._buckets.values() for obs in bucket]
def _fit(self) -> bool:
observations = self._gather_observations()
if len(observations) < self.min_observations:
return False
X = np.array([o[0] for o in observations])
if self._ndim == 1:
X = X.reshape(-1, 1)
y = np.array([o[1] for o in observations])
self._model.fit(X, y)
# Negative coefficients mean "more load → less compute time", which
# is physically impossible. Reject the fit so callers see the model
# as not ready rather than making inverted scaling decisions.
if np.any(self._model.coef_ < 0):
logger.warning(
f"Regression produced negative coefficients {self._model.coef_.tolist()}, "
"model rejected — scaling will be skipped until more data arrives"
)
self._is_fitted = False
return False
self._is_fitted = True
return True
def _ensure_fitted(self) -> bool:
return self._is_fitted or self._fit()
def has_sufficient_data(self) -> bool:
return self._total_observations >= self.min_observations
@property
def num_observations(self) -> int:
return self._total_observations
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Decode engine performance model.
Regression: wall_time = f(num_decode_requests, sum_decode_kv_tokens)
"""
import logging
from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage
logger = logging.getLogger(__name__)
class DecodeRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from decode batch composition."""
def __init__(
self,
max_num_fpm_samples: int,
min_observations: int = 5,
bucket_count: int = 16,
):
super().__init__(
max_num_fpm_samples, min_observations, ndim=2, bucket_count=bucket_count
)
self._avg_decode_len = _MovingAverage(max_num_fpm_samples)
self._avg_num_decode = _MovingAverage(max_num_fpm_samples)
self._max_observed_kv: float = 0.0
def _extract_x(self, fpm: ForwardPassMetrics) -> list[float]:
sched = fpm.scheduled_requests
return [float(sched.num_decode_requests), float(sched.sum_decode_kv_tokens)]
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_decode_requests > 0:
self._avg_decode_len.add(
sched.sum_decode_kv_tokens / sched.num_decode_requests
)
self._avg_num_decode.add(float(sched.num_decode_requests))
if sched.sum_decode_kv_tokens > self._max_observed_kv:
self._max_observed_kv = float(sched.sum_decode_kv_tokens)
@property
def avg_decode_length(self) -> float:
return self._avg_decode_len.value
def _predict_2d(self, num_requests: float, kv_tokens: float) -> float:
return max(
1e-6, float(self._model.predict(np.array([[num_requests, kv_tokens]]))[0])
)
def estimate_next_itl(
self,
scheduled_decode_kv: int,
queued_decode_kv: int,
) -> Optional[float]:
"""Estimate the next decode iteration time in seconds."""
if not self._ensure_fitted():
return None
total_kv = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
num_req = self._avg_num_decode.value + 1
return self._predict_2d(num_req, total_kv)
def find_best_engine_decode_rps(
self, itl: float, context_length: float, osl: float
) -> tuple[float, float]:
"""Find the maximum decode engine request rate within an ITL target.
Binary searches over batch_size at the given context_length for the
maximum batch_size where predicted wall_time * 1000 <= itl. If even
batch_size=1 violates the target, warns but returns the best
achievable rate at batch_size=1 so the caller can still scale.
Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
Returns:
(engine_rps, actual_itl_ms) -- 0 rps signals an error
(model not fitted or invalid input); positive rps is
the best achievable rate with the predicted ITL.
"""
if not self._ensure_fitted() or context_length <= 0 or osl <= 0 or itl <= 0:
return (0.0, 0.0)
max_batch = (
max(1, int(self._max_observed_kv / context_length))
if self._max_observed_kv > 0
else 256
)
lo, hi = 1, max_batch
best_bs, best_wt = 1, self._predict_2d(1, context_length)
if best_wt * 1000.0 > itl:
logger.warning(
f"ITL SLA unreachable: predicted {best_wt * 1000.0:.1f}ms "
f"> target {itl:.1f}ms at batch_size=1, ctx_len={context_length:.0f}"
)
return (best_bs / (osl * best_wt), best_wt * 1000.0)
while lo <= hi:
mid = (lo + hi) // 2
kv = mid * context_length
wt = self._predict_2d(mid, kv)
if wt * 1000.0 <= itl:
best_bs, best_wt = mid, wt
lo = mid + 1
else:
hi = mid - 1
engine_rps = best_bs / (osl * best_wt)
return (engine_rps, best_wt * 1000.0)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Prefill engine performance model.
Regression: wall_time = f(sum_prefill_tokens)
"""
import logging
import math
from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage
logger = logging.getLogger(__name__)
class PrefillRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from scheduled prefill tokens.
Simulation: estimate TTFT by chunking queued_prefill_tokens + avg_isl
into max_num_batched_tokens-sized iterations and summing
the predicted wall time for each.
"""
def __init__(
self,
max_num_fpm_samples: int,
min_observations: int = 5,
bucket_count: int = 16,
):
super().__init__(
max_num_fpm_samples, min_observations, ndim=1, bucket_count=bucket_count
)
self._avg_isl = _MovingAverage(max_num_fpm_samples)
self._avg_num_prefill = _MovingAverage(max_num_fpm_samples)
def _extract_x(self, fpm: ForwardPassMetrics) -> float:
return float(fpm.scheduled_requests.sum_prefill_tokens)
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_prefill_requests > 0:
self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
self._avg_num_prefill.add(float(sched.num_prefill_requests))
@property
def avg_isl(self) -> float:
return self._avg_isl.value
def _predict_wall_time(self, prefill_tokens: float) -> float:
return max(1e-6, float(self._model.predict(np.array([[prefill_tokens]]))[0]))
def estimate_next_ttft(
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
) -> Optional[float]:
"""Simulate prefill scheduling to estimate TTFT for the next request.
Returns estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
if total_tokens <= 0:
return 0.0
num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
total_time = 0.0
remaining = total_tokens
for _ in range(num_iterations):
chunk = min(remaining, max_num_batched_tokens)
total_time += self._predict_wall_time(chunk)
remaining -= chunk
return total_time
def find_best_engine_prefill_rps(
self, ttft_sla: float, isl: float
) -> tuple[float, float]:
"""Find prefill engine request rate under a TTFT target.
Predicts wall_time for a single prefill at the given ISL.
If the predicted TTFT exceeds the SLA, logs a warning but
still returns the best achievable rate so the caller can
scale based on load matching.
Returns:
(engine_rps, actual_ttft_ms) -- 0 rps signals an error
(model not fitted or invalid input); positive rps is
the best achievable rate with the predicted TTFT.
"""
if not self._ensure_fitted() or isl <= 0:
return (0.0, 0.0)
wt = self._predict_wall_time(isl)
actual_ttft_ms = wt * 1000.0
engine_rps = 1.0 / wt
if actual_ttft_ms > ttft_sla:
logger.warning(
f"TTFT SLA unreachable: predicted {actual_ttft_ms:.1f}ms "
f"> target {ttft_sla:.1f}ms at ISL={isl:.0f}"
)
return (engine_rps, actual_ttft_ms)
...@@ -75,42 +75,27 @@ class PrefillPlanner(BasePlanner): ...@@ -75,42 +75,27 @@ class PrefillPlanner(BasePlanner):
label="prefill TTFT", label="prefill TTFT",
) )
def _update_correction_factor(self) -> bool:
assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None
expect_ttft = self.prefill_interpolator.interpolate_ttft(self.last_metrics.isl)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.p_correction_factor.set(self.p_correction_factor)
return True
def _compute_replica_requirements( def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float self, next_num_req: float, next_isl: float, next_osl: float
) -> int: ) -> Optional[int]:
pred_prefill_throughput = ( demand_rps = next_num_req / self.config.throughput_adjustment_interval
next_num_req engine_rps, actual_ttft_ms = self.ttft_regression.find_best_engine_prefill_rps(
* next_isl ttft_sla=self.config.ttft, isl=next_isl
/ self.config.throughput_adjustment_interval
* min(1, self.p_correction_factor)
) )
p_thpt_per_gpu = self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl) if engine_rps <= 0:
if p_thpt_per_gpu <= 0: logger.warning("Prefill perf model not ready, skipping throughput scaling")
return None
if actual_ttft_ms > self.config.ttft:
logger.warning( logger.warning(
f"p_thpt_per_gpu is {p_thpt_per_gpu} " f"Prefill TTFT SLA not met: {actual_ttft_ms:.1f}ms > "
"(no throughput satisfies TTFT target), falling back to min_endpoint" f"{self.config.ttft:.1f}ms, scaling with best achievable rate"
)
return self.config.min_endpoint
assert self.config.prefill_engine_num_gpu is not None
next_num_p = math.ceil(
pred_prefill_throughput
/ p_thpt_per_gpu
/ self.config.prefill_engine_num_gpu
) )
next_num_p = math.ceil(demand_rps / engine_rps)
next_num_p = max(next_num_p, self.config.min_endpoint) next_num_p = max(next_num_p, self.config.min_endpoint)
logger.info( logger.info(
f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / " f"Prefill: {demand_rps:.2f}(demand rps) / "
f"{p_thpt_per_gpu * self.config.prefill_engine_num_gpu:.2f}(p_engine_cap) = " f"{engine_rps:.2f}(engine rps) = {next_num_p}(num_p), "
f"{next_num_p}(num_p)" f"est_ttft={actual_ttft_ms:.1f}ms"
) )
return next_num_p return next_num_p
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from typing import Optional
import numpy as np
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
MISSING_PROFILING_DATA_ERROR_MESSAGE = (
"SLA-Planner requires pre-deployment profiling results to run.\n"
"Please follow /docs/components/profiler/profiler-guide.md to run the profiling first,\n"
"and make sure the profiling results are present in --profile-results-dir."
)
class PrefillInterpolator:
"""
Takes input from results of pre-deployment performance profiling to interpolate
throughput/gpu and TTFT for a given ISL.
"""
def __init__(
self,
profile_results_dir: Optional[str] = None,
raw_data: Optional[dict] = None,
):
if profile_results_dir:
prefill_npz_fn = (
f"{profile_results_dir}/selected_prefill_interpolation/raw_data.npz"
)
try:
with np.load(prefill_npz_fn) as raw_data:
self.prefill_isl = raw_data["prefill_isl"]
self.prefill_ttft = raw_data["prefill_ttft"] # in milliseconds
self.prefill_thpt_per_gpu = raw_data["prefill_thpt_per_gpu"]
except FileNotFoundError:
# Fallback to JSON provided via ConfigMap mounted at profile_results_dir
json_fn = os.path.join(profile_results_dir, "prefill_raw_data.json")
try:
with open(json_fn, "r") as f:
data = json.load(f)
self.prefill_isl = np.array(data["prefill_isl"]) # type: ignore[index]
self.prefill_ttft = np.array(data["prefill_ttft"]) # type: ignore[index]
self.prefill_thpt_per_gpu = np.array(data["prefill_thpt_per_gpu"]) # type: ignore[index]
except FileNotFoundError:
raise FileNotFoundError(
f"Prefill interpolation files not found: {prefill_npz_fn} and {json_fn}\n"
f"{MISSING_PROFILING_DATA_ERROR_MESSAGE}"
)
elif raw_data:
self.prefill_isl = raw_data["prefill_isl"]
self.prefill_ttft = raw_data["prefill_ttft"] # in milliseconds
self.prefill_thpt_per_gpu = raw_data["prefill_thpt_per_gpu"]
else:
raise ValueError("Either profile_results_dir or raw_data must be provided")
self.min_isl = min(self.prefill_isl)
self.max_isl = max(self.prefill_isl)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 1d interpolation
self.ttft_interpolator = scipy.interpolate.interp1d(
self.prefill_isl, self.prefill_ttft, kind="cubic"
)
self.thpt_interpolator = scipy.interpolate.interp1d(
self.prefill_isl, self.prefill_thpt_per_gpu, kind="cubic"
)
def interpolate_ttft(self, isl: float) -> float:
isl = max(self.min_isl, min(isl, self.max_isl))
return self.ttft_interpolator(isl)
def interpolate_thpt_per_gpu(self, isl: float) -> float:
isl = max(self.min_isl, min(isl, self.max_isl))
return self.thpt_interpolator(isl)
class DecodeInterpolator:
"""
Takes input from results of pre-deployment performance profiling to interpolate
throughput/gpu and ITL for a given decode context length.
"""
def __init__(
self,
profile_results_dir: Optional[str] = None,
resolution: int = 100,
raw_data: Optional[dict] = None,
):
if profile_results_dir:
decode_npz_fn = (
f"{profile_results_dir}/selected_decode_interpolation/raw_data.npz"
)
try:
with np.load(decode_npz_fn) as raw_data:
self.x_kv_usage = raw_data["x_kv_usage"]
self.y_context_length = raw_data["y_context_length"]
self.z_itl = raw_data["z_itl"]
self.z_thpt_per_gpu = raw_data["z_thpt_per_gpu"]
self.max_kv_tokens = raw_data["max_kv_tokens"][0]
except FileNotFoundError:
# Fallback to JSON provided via ConfigMap mounted at profile_results_dir
json_fn = os.path.join(profile_results_dir, "decode_raw_data.json")
try:
with open(json_fn, "r") as f:
data = json.load(f)
self.x_kv_usage = np.array(data["x_kv_usage"]) # type: ignore[index]
self.y_context_length = np.array(data["y_context_length"]) # type: ignore[index]
self.z_itl = np.array(data["z_itl"]) # type: ignore[index]
self.z_thpt_per_gpu = np.array(data["z_thpt_per_gpu"]) # type: ignore[index]
self.max_kv_tokens = int(data["max_kv_tokens"]) # type: ignore[index]
except FileNotFoundError:
raise FileNotFoundError(
f"Decode interpolation files not found: {decode_npz_fn} and {json_fn}\n"
f"{MISSING_PROFILING_DATA_ERROR_MESSAGE}"
)
elif raw_data:
self.x_kv_usage = raw_data["x_kv_usage"]
self.y_context_length = raw_data["y_context_length"]
self.z_itl = raw_data["z_itl"]
self.z_thpt_per_gpu = raw_data["z_thpt_per_gpu"]
self.max_kv_tokens = raw_data["max_kv_tokens"][0]
else:
raise ValueError("Either profile_results_dir or raw_data must be provided")
# pre-compute the interpolation grid for fast lookup
self.resolution = resolution
self.xi = np.linspace(0, 1, resolution)
self.yi = np.linspace(0, max(self.y_context_length), resolution)
self.X: np.ndarray
self.Y: np.ndarray
self.X, self.Y = np.meshgrid(self.xi, self.yi)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 2d interpolation with fallback for NaN values
self.itl_interpolator = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length),
self.z_itl,
(self.X, self.Y),
method="cubic",
)
# Fill NaN values using nearest neighbor interpolation
nan_mask = np.isnan(self.itl_interpolator)
if np.any(nan_mask):
itl_nearest = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length),
self.z_itl,
(self.X, self.Y),
method="nearest",
)
self.itl_interpolator[nan_mask] = itl_nearest[nan_mask]
# ITL values are in milliseconds
self.thpt_interpolator = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length),
self.z_thpt_per_gpu,
(self.X, self.Y),
method="cubic",
)
# Fill NaN values using nearest neighbor interpolation
nan_mask = np.isnan(self.thpt_interpolator)
if np.any(nan_mask):
thpt_nearest = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length),
self.z_thpt_per_gpu,
(self.X, self.Y),
method="nearest",
)
self.thpt_interpolator[nan_mask] = thpt_nearest[nan_mask]
def compute_idx(self, concurrency: float, context_length: float) -> tuple[int, int]:
kv_usage = concurrency * context_length / self.max_kv_tokens
# Calculate x index (kv_usage)
ix = int(
np.clip(
np.round((kv_usage - self.xi[0]) / (self.xi[1] - self.xi[0])),
0,
self.resolution - 1,
)
)
# Calculate y index (context_length)
iy = int(
np.clip(
np.round((context_length - self.yi[0]) / (self.yi[1] - self.yi[0])),
0,
self.resolution - 1,
)
)
return ix, iy
def interpolate_itl(self, concurrency: float, context_length: float) -> float:
ix, iy = self.compute_idx(concurrency, context_length)
return self.itl_interpolator[iy, ix]
def interpolate_thpt_per_gpu(
self, concurrency: float, context_length: float
) -> float:
ix, iy = self.compute_idx(concurrency, context_length)
return self.thpt_interpolator[iy, ix]
def find_best_throughput_per_gpu(
self, itl: float, context_length: float
) -> tuple[float, float, float]:
# find the max kv_load that has itl <= target itl
# here we cannot use binary search as interpolated itl might not be monotonic
iy = int(
np.clip(
np.round((context_length - self.yi[0]) / (self.yi[1] - self.yi[0])),
0,
self.resolution - 1,
)
)
iy = max(0, min(iy, self.resolution - 1))
for ix in range(self.resolution - 1, -1, -1):
if self.itl_interpolator[iy, ix] <= itl:
return (
self.thpt_interpolator[iy, ix],
self.itl_interpolator[iy, ix],
self.xi[ix],
)
return self.thpt_interpolator[iy, 0], self.itl_interpolator[iy, 0], self.xi[0]
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--profile-results-dir", type=str, required=True)
parser.add_argument("--isl", type=int, default=3000)
parser.add_argument("--osl", type=int, default=150)
parser.add_argument("--ttft", type=float, default=100.0, help="in milliseconds")
parser.add_argument("--itl", type=float, default=10.0, help="in milliseconds")
args = parser.parse_args()
print(f"ISL={args.isl}, OSL={args.osl}")
print(f"TTFT={args.ttft}ms, ITL={args.itl}ms")
print(f"Using profile results from {args.profile_results_dir}")
print("")
# first interpolate prefill
print("Interpolating prefill performance ...")
prefill_interpolator = PrefillInterpolator(args.profile_results_dir)
est_ttft = prefill_interpolator.interpolate_ttft(args.isl)
est_thpt_per_gpu = prefill_interpolator.interpolate_thpt_per_gpu(args.isl)
if est_ttft <= args.ttft:
print(
f"\tEstimated TTFT={est_ttft:.2f}ms <= target TTFT={args.ttft:.2f}ms. Requests can queue {args.ttft - est_ttft:.2f}ms maximally while meeting TTFT SLA."
)
else:
print(
f"\tEstimated TTFT={est_ttft:.2f}ms > target TTFT={args.ttft:.2f}ms. Cannot meet TTFT SLA."
)
print(
f"\tEstimated throughput: {est_thpt_per_gpu:.2f} tokens/s/gpu. Request rate at {est_thpt_per_gpu / args.isl:.2f} requests/s will saturate one GPU."
)
print("")
# then interpolate decode
decode_interpolator = DecodeInterpolator(args.profile_results_dir)
print("Interpolating decode performance ...")
context_length = args.isl + args.osl // 2
print(f"\tAverage context length: isl + osl/2 = {context_length}.")
(
est_thpt_per_gpu,
est_itl,
est_kv_usage,
) = decode_interpolator.find_best_throughput_per_gpu(args.itl, context_length)
if est_itl <= args.itl:
print(
f"\tEstimated ITL={est_itl:.2f}ms <= target ITL={args.itl:.2f}ms at {est_kv_usage*100:.2f}% active kv usage."
)
print(
f"\tEstimated throughput: {est_thpt_per_gpu:.2f} token/s/gpu. Request rate at {est_thpt_per_gpu / args.osl:.2f} requests/s will saturate one GPU."
)
else:
print(
f"\tEstimated ITL={est_itl:.2f}ms > target ITL={args.itl:.2f}ms. Cannot meet ITL SLA."
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pre-deployment FPM data fetching for planner regression bootstrap.
Priority chain:
1. Call ``get_perf_metrics`` Dynamo endpoint (PR 7779 self-benchmark)
2. Fallback: convert legacy profiler npz to synthetic FPMs
3. If both fail: raise
"""
import asyncio
import json
import logging
import os
from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
ScheduledRequestMetrics,
)
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.monitoring.worker_info import WorkerInfo
logger = logging.getLogger(__name__)
async def fetch_pre_deployment_metrics(
runtime: "object", # DistributedRuntime; typed loosely to avoid hard import
namespace: str,
worker_info: WorkerInfo,
profile_results_dir: Optional[str],
component_type: SubComponentType,
) -> list[ForwardPassMetrics]:
"""Fetch pre-deployment engine perf data as an FPM list.
1. Try ``get_perf_metrics`` endpoint (PR 7779 self-benchmark)
2. Fallback: convert legacy profiler data (npz or JSON) to synthetic FPMs
3. If both fail: raise
Args:
runtime: DistributedRuntime instance.
namespace: Dynamo namespace.
worker_info: WorkerInfo for the target component.
profile_results_dir: Path to legacy profiler npz data (fallback).
component_type: PREFILL or DECODE.
Returns:
List of ForwardPassMetrics suitable for regression bootstrap.
"""
fpms = await _try_endpoint(runtime, namespace, worker_info, component_type)
if fpms:
logger.info(
f"Loaded {len(fpms)} pre-deployment FPMs from get_perf_metrics endpoint"
)
return fpms
if profile_results_dir:
try:
fpms = _convert_profiling_data_to_fpms(profile_results_dir, component_type)
if fpms:
logger.info(
f"Loaded {len(fpms)} FPMs from legacy profiler npz at {profile_results_dir}"
)
return fpms
except Exception as e:
logger.warning(
f"Failed to load profiling npz from {profile_results_dir}: {e}"
)
raise RuntimeError(
"Failed to obtain pre-deployment performance data. "
"Either enable --benchmark-mode on the vLLM worker (get_perf_metrics endpoint) "
"or provide profiling results via --profile-results-dir."
)
async def _try_endpoint(
runtime: "object",
namespace: str,
worker_info: WorkerInfo,
component_type: SubComponentType,
) -> list[ForwardPassMetrics]:
"""Try to fetch benchmark FPMs from the get_perf_metrics Dynamo endpoint."""
if not worker_info.component_name:
return []
try:
endpoint = runtime.endpoint( # type: ignore[attr-defined]
f"{namespace}.{worker_info.component_name}.get_perf_metrics"
)
client = await endpoint.client()
await asyncio.sleep(0.1)
response_stream = await client.round_robin(None)
benchmark_data = None
async for resp in response_stream:
benchmark_data = resp.data()
break
if benchmark_data is None:
return []
if isinstance(benchmark_data, str):
benchmark_data = json.loads(benchmark_data)
if isinstance(benchmark_data, dict) and benchmark_data.get("status") == "error":
logger.info(
f"get_perf_metrics returned error: {benchmark_data.get('message')}"
)
return []
fpms = _extract_fpms_from_benchmark(benchmark_data, component_type)
if not fpms:
logger.warning(
"get_perf_metrics returned data but no valid FPMs were extracted "
"(possible schema mismatch)"
)
return fpms
except (ConnectionError, TimeoutError, OSError) as e:
logger.info(f"get_perf_metrics endpoint not available: {e}")
return []
except Exception as e:
logger.warning(f"get_perf_metrics unexpected error: {e}")
return []
def _extract_fpms_from_benchmark(
benchmark_data: dict,
component_type: SubComponentType,
) -> list[ForwardPassMetrics]:
"""Extract ForwardPassMetrics from PR 7779 benchmark results dict."""
import msgspec
results = benchmark_data.get("results", [])
fpms: list[ForwardPassMetrics] = []
target_types: set[str] = set()
if component_type == SubComponentType.PREFILL:
target_types = {"prefill"}
elif component_type == SubComponentType.DECODE:
target_types = {"decode"}
else:
target_types = {"prefill", "decode"}
for result in results:
point = result.get("point", {})
point_type = point.get("point_type", "")
if point_type not in target_types:
continue
for fpm_dict in result.get("fpms", []):
try:
raw = json.dumps(fpm_dict).encode()
fpm = msgspec.json.decode(raw, type=ForwardPassMetrics)
if fpm.wall_time > 0:
fpms.append(fpm)
except Exception as e:
logger.warning(f"Failed to decode FPM entry: {e}")
continue
return fpms
def _convert_profiling_data_to_fpms(
profile_results_dir: str,
component_type: SubComponentType,
) -> list[ForwardPassMetrics]:
"""Convert legacy profiler data (npz or JSON) to synthetic ForwardPassMetrics."""
fpms: list[ForwardPassMetrics] = []
if component_type in (SubComponentType.PREFILL,):
fpms.extend(_convert_prefill_profiling(profile_results_dir))
if component_type in (SubComponentType.DECODE,):
fpms.extend(_convert_decode_profiling(profile_results_dir))
return fpms
def _convert_prefill_profiling(profile_results_dir: str) -> list[ForwardPassMetrics]:
npz_path = os.path.join(
profile_results_dir, "selected_prefill_interpolation", "raw_data.npz"
)
json_path = os.path.join(profile_results_dir, "prefill_raw_data.json")
prefill_isl: np.ndarray
prefill_ttft: np.ndarray
if os.path.exists(npz_path):
with np.load(npz_path) as data:
prefill_isl = data["prefill_isl"]
prefill_ttft = data["prefill_ttft"]
elif os.path.exists(json_path):
with open(json_path) as f:
data = json.load(f)
prefill_isl = np.array(data["prefill_isl"])
prefill_ttft = np.array(data["prefill_ttft"])
else:
raise FileNotFoundError(
f"Prefill profiling data not found at {npz_path} or {json_path}"
)
fpms = []
for isl_val, ttft_ms in zip(prefill_isl, prefill_ttft):
fpms.append(
ForwardPassMetrics(
wall_time=float(ttft_ms) / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_prefill_requests=1,
sum_prefill_tokens=int(isl_val),
),
)
)
return fpms
def _convert_decode_profiling(profile_results_dir: str) -> list[ForwardPassMetrics]:
npz_path = os.path.join(
profile_results_dir, "selected_decode_interpolation", "raw_data.npz"
)
json_path = os.path.join(profile_results_dir, "decode_raw_data.json")
x_kv_usage: np.ndarray
y_context_length: np.ndarray
z_itl: np.ndarray
max_kv_tokens: int
if os.path.exists(npz_path):
with np.load(npz_path) as data:
x_kv_usage = data["x_kv_usage"]
y_context_length = data["y_context_length"]
z_itl = data["z_itl"]
max_kv_tokens = int(data["max_kv_tokens"][0])
elif os.path.exists(json_path):
with open(json_path) as f:
data = json.load(f)
x_kv_usage = np.array(data["x_kv_usage"])
y_context_length = np.array(data["y_context_length"])
z_itl = np.array(data["z_itl"])
max_kv_tokens = int(data["max_kv_tokens"])
else:
raise FileNotFoundError(
f"Decode profiling data not found at {npz_path} or {json_path}"
)
fpms = []
for kv_usage, ctx_len, itl_ms in zip(x_kv_usage, y_context_length, z_itl):
sum_decode_kv = int(round(float(kv_usage) * max_kv_tokens))
batch_size = (
max(1, int(round(sum_decode_kv / float(ctx_len)))) if ctx_len > 0 else 1
)
fpms.append(
ForwardPassMetrics(
wall_time=float(itl_ms) / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_decode_requests=batch_size,
sum_decode_kv_tokens=sum_decode_kv,
),
)
)
return fpms
...@@ -36,14 +36,6 @@ class PlannerPrometheusMetrics: ...@@ -36,14 +36,6 @@ class PlannerPrometheusMetrics:
f"{prefix}:observed_osl", "Observed output sequence length" f"{prefix}:observed_osl", "Observed output sequence length"
) )
# Correction factors
self.p_correction_factor = Gauge(
f"{prefix}:p_correction_factor", "Prefill correction factor"
)
self.d_correction_factor = Gauge(
f"{prefix}:d_correction_factor", "Decode correction factor"
)
# Predicted metrics # Predicted metrics
self.predicted_request_rate = Gauge( self.predicted_request_rate = Gauge(
f"{prefix}:predicted_request_rate", "Predicted request rate (req/s)" f"{prefix}:predicted_request_rate", "Predicted request rate (req/s)"
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.budget import (
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.core.decode import DecodePlanner
from dynamo.planner.core.prefill import PrefillPlanner
from dynamo.planner.core.state import PlannerSharedState
from dynamo.planner.offline.dryrun_plot import create_dryrun_plot
from dynamo.planner.offline.trace_data import extract_metrics_from_mooncake
def run_sla_planner_dryrun(
config: PlannerConfig,
dataset: str,
start_num_p: int = 1,
start_num_d: int = 1,
output_plot: str = "dryrun_plot.png",
) -> None:
if config.enable_load_scaling:
raise ValueError(
"Load-based scaling is not supported in dryrun mode. "
"Set enable_load_scaling to false in the config."
)
if config.prefill_engine_num_gpu is None:
config.prefill_engine_num_gpu = 1
if config.decode_engine_num_gpu is None:
config.decode_engine_num_gpu = 1
warmup_metrics = None
if config.load_predictor_warmup_trace is not None:
warmup_metrics = extract_metrics_from_mooncake(
config.load_predictor_warmup_trace,
config.throughput_adjustment_interval,
)
metrics = extract_metrics_from_mooncake(
dataset, config.throughput_adjustment_interval
)
if not metrics:
raise ValueError("Empty metrics dataset: cannot run dryrun")
mode = config.mode
prefill_planner: Optional[PrefillPlanner] = None
decode_planner: Optional[DecodePlanner] = None
if mode == "disagg":
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(
None, config, dryrun=True, shared_state=shared_state
)
decode_planner = DecodePlanner(
None, config, dryrun=True, shared_state=shared_state
)
elif mode == "prefill":
prefill_planner = PrefillPlanner(None, config, dryrun=True)
elif mode == "decode":
decode_planner = DecodePlanner(None, config, dryrun=True)
else:
raise ValueError(f"Invalid planner mode: {mode}")
def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
"""safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
assert prefill_planner is not None
actual_ttft = prefill_planner.prefill_interpolator.interpolate_ttft(isl)
if actual_ttft > ttft:
return 0
return num_p * prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
isl
)
def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
"""safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
assert decode_planner is not None
(
pred_decode_thpt_per_gpu,
actual_itl,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=itl, context_length=isl + osl / 2
)
if actual_itl > itl:
return 0
return num_d * pred_decode_thpt_per_gpu
time_series = [0]
rr = [metrics[0]["request_count"]]
est_rr = [metrics[0]["request_count"]]
isl = [metrics[0]["avg_isl"]]
est_isl = [metrics[0]["avg_isl"]]
osl = [metrics[0]["avg_osl"]]
est_osl = [metrics[0]["avg_osl"]]
interval = config.throughput_adjustment_interval
if prefill_planner is not None:
num_p = [start_num_p]
p_thpt = [rr[0] * isl[0]]
safe_p_thpt = [compute_safe_p_thpt(start_num_p, isl[0], config.ttft) * interval]
prefill_planner.dryrun_observe_traffic_stats(rr[0], isl[0], osl[0])
else:
num_p = [0]
p_thpt = [0]
safe_p_thpt = [0]
if decode_planner is not None:
num_d = [start_num_d]
d_thpt = [rr[0] * osl[0]]
safe_d_thpt = [
compute_safe_d_thpt(start_num_d, isl[0], osl[0], config.itl) * interval
]
decode_planner.dryrun_observe_traffic_stats(rr[0], isl[0], osl[0])
else:
num_d = [0]
d_thpt = [0]
safe_d_thpt = [0]
predictor_planner = prefill_planner or decode_planner
assert predictor_planner is not None
for metric in metrics[1:]:
time_series.append(time_series[-1] + interval)
_est_rr, _est_isl, _est_osl = predictor_planner.predict_load()
# predict_load() returns Optional[float] values; in dryrun mode with
# pre-loaded data the predictors always return valid floats.
assert (
_est_rr is not None and _est_isl is not None and _est_osl is not None
), "predict_load() returned None in dryrun mode"
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
_num_p = (
prefill_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if prefill_planner is not None
else 0
)
_num_d = (
decode_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if decode_planner is not None
else 0
)
if prefill_planner is not None and decode_planner is not None:
_num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, config)
elif prefill_planner is not None:
assert config.prefill_engine_num_gpu is not None
_num_p = _apply_component_gpu_budget(
_num_p, config.prefill_engine_num_gpu, config
)
elif decode_planner is not None:
assert config.decode_engine_num_gpu is not None
_num_d = _apply_component_gpu_budget(
_num_d, config.decode_engine_num_gpu, config
)
num_p.append(_num_p)
num_d.append(_num_d)
for planner in [prefill_planner, decode_planner]:
if planner is not None:
planner.dryrun_observe_traffic_stats(
metric["request_count"], metric["avg_isl"], metric["avg_osl"]
)
rr.append(metric["request_count"])
isl.append(metric["avg_isl"])
osl.append(metric["avg_osl"])
p_thpt.append(rr[-1] * isl[-1] if prefill_planner is not None else 0)
d_thpt.append(rr[-1] * osl[-1] if decode_planner is not None else 0)
safe_p_thpt.append(
compute_safe_p_thpt(num_p[-1], isl[-1], config.ttft) * interval
if prefill_planner is not None
else 0
)
safe_d_thpt.append(
compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], config.itl) * interval
if decode_planner is not None
else 0
)
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time_series,
rr=rr,
est_rr=est_rr,
isl=isl,
est_isl=est_isl,
osl=osl,
est_osl=est_osl,
num_p=num_p,
p_thpt=p_thpt,
safe_p_thpt=safe_p_thpt,
num_d=num_d,
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
def create_dryrun_plot(
time: list,
rr: list,
est_rr: list,
isl: list,
est_isl: list,
osl: list,
est_osl: list,
num_p: list,
p_thpt: list,
safe_p_thpt: list,
num_d: list,
d_thpt: list,
safe_d_thpt: list,
output_path: str,
warmup_time: list | None = None,
warmup_rr: list | None = None,
warmup_isl: list | None = None,
warmup_osl: list | None = None,
) -> None:
"""
Create a comprehensive dryrun plot with 4 subplots showing various metrics over time.
Args:
time: List of time points
rr: List of actual request rates
est_rr: List of estimated request rates
isl: List of actual input sequence lengths
est_isl: List of estimated input sequence lengths
osl: List of actual output sequence lengths
est_osl: List of estimated output sequence lengths
num_p: List of prefill worker counts
p_thpt: List of actual prefill throughputs
safe_p_thpt: List of safe prefill throughput limits
num_d: List of decode worker counts
d_thpt: List of actual decode throughputs
safe_d_thpt: List of safe decode throughput limits
output_path: Path where the plot should be saved
warmup_time: Optional list of warmup time points (negative seconds)
warmup_rr: Optional list of warmup request rates (same units as rr)
warmup_isl: Optional list of warmup input sequence lengths
warmup_osl: Optional list of warmup output sequence lengths
"""
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
# Plot 1: Request Rate
if warmup_time is not None and warmup_rr is not None and len(warmup_time) > 0:
ax1.plot(
warmup_time,
warmup_rr,
"b-",
alpha=0.35,
linewidth=2,
label="Warmup Request Rate",
)
ax1.plot(time, rr, "b-", label="Actual Request Rate", linewidth=2)
ax1.plot(time, est_rr, "r--", label="Predicted Request Rate", linewidth=2)
if warmup_time is not None and warmup_rr is not None:
ax1.axvline(0, color="k", linestyle=":", linewidth=2, label="Warmup Boundary")
ax1.set_xlabel("Time (s)")
ax1.set_ylabel("Request Rate")
ax1.set_ylim(bottom=0)
ax1.set_title("Request Rate Over Time")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot 2: Sequence Lengths
if (
warmup_time is not None
and warmup_isl is not None
and warmup_osl is not None
and len(warmup_time) > 0
):
ax2.plot(
warmup_time,
warmup_isl,
"g-",
alpha=0.35,
linewidth=2,
label="Warmup ISL",
)
ax2.plot(
warmup_time,
warmup_osl,
"m-",
alpha=0.35,
linewidth=2,
label="Warmup OSL",
)
ax2.plot(time, isl, "g-", label="Actual ISL", linewidth=2)
ax2.plot(time, est_isl, "g--", label="Predicted ISL", linewidth=2)
ax2.plot(time, osl, "m-", label="Actual OSL", linewidth=2)
ax2.plot(time, est_osl, "m--", label="Predicted OSL", linewidth=2)
if warmup_time is not None and warmup_isl is not None and warmup_osl is not None:
ax2.axvline(0, color="k", linestyle=":", linewidth=2, label="Warmup Boundary")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Num Tokens")
ax2.set_ylim(bottom=0)
ax2.set_title("Input/Output Sequence Lengths Over Time")
ax2.legend()
ax2.grid(True, alpha=0.3)
# Plot 3: Worker Counts
ax3.plot(time, p_thpt, "b-", label="Actual Prefill Throughput", linewidth=2)
ax3.plot(
time, safe_p_thpt, "b--", label="Safe Prefill Throughput Limit", linewidth=2
)
ax3_right = ax3.twinx()
ax3_right.plot(time, num_p, "c-", label="Prefill Workers", linewidth=2, marker="o")
ax3_right.set_ylabel("Number of Workers")
lines1, labels1 = ax3.get_legend_handles_labels()
lines2, labels2 = ax3_right.get_legend_handles_labels()
ax3.legend(lines1 + lines2, labels1 + labels2, loc="upper left")
ax3.set_xlabel("Time (s)")
ax3.set_ylabel("Throughput (tok/adjustment_interval)")
ax3.set_ylim(bottom=0)
ax3_right.set_ylabel("Number of Workers")
ax3_right.set_ylim(bottom=0)
ax3.set_title("Prefill Load and Workers")
ax3.grid(True, alpha=0.3)
# Plot 4: Throughput Comparison
ax4.plot(time, d_thpt, "r-", label="Actual Decode Throughput", linewidth=2)
ax4.plot(
time, safe_d_thpt, "r--", label="Safe Decode Throughput Limit", linewidth=2
)
ax4_right = ax4.twinx()
ax4_right.plot(
time, num_d, "orange", label="Decode Workers", linewidth=2, marker="o"
)
ax4_right.set_ylabel("Number of Workers")
lines1, labels1 = ax4.get_legend_handles_labels()
lines2, labels2 = ax4_right.get_legend_handles_labels()
ax4.legend(lines1 + lines2, labels1 + labels2, loc="upper left")
ax4.set_xlabel("Time (s)")
ax4.set_ylabel("Throughput (tok/adjustment_interval)")
ax4.set_ylim(bottom=0)
ax4_right.set_ylabel("Number of Workers")
ax4_right.set_ylim(bottom=0)
ax4.set_title("Decode Load and Workers")
ax4.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment