Unverified commit e2f1e04f, authored by Hongkuan Zhou, committed by GitHub
Browse files

refactor: separate planner into prefill/decode planner (#5622)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 77aadb72
......@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
mode = "disagg" # ["disagg", "prefill", "decode"]
class VllmComponentName:
......
......@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
......@@ -126,34 +128,45 @@ class KubernetesConnector(PlannerConnector):
errors = []
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
except PlannerError as e:
errors.append(str(e))
if require_prefill:
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
except PlannerError as e:
errors.append(str(e))
if require_decode:
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
)
except PlannerError as e:
errors.append(str(e))
try:
get_service_from_sub_component_type_or_name(
self.get_model_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
require_prefill=require_prefill,
require_decode=require_decode,
)
except PlannerError as e:
errors.append(str(e))
try:
self.get_model_name(deployment)
except PlannerError as e:
errors.append(str(e))
# Raise combined error if any issues found
if errors:
raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str:
def get_model_name(
self,
deployment: Optional[dict] = None,
require_prefill: bool = True,
require_decode: bool = True,
) -> str:
"""Get the model name from the deployment"""
try:
if deployment is None:
......@@ -163,16 +176,20 @@ class KubernetesConnector(PlannerConnector):
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate
prefill_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
prefill_model_name = prefill_service.get_model_name()
decode_model_name = decode_service.get_model_name()
prefill_model_name = None
decode_model_name = None
if require_prefill:
prefill_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
prefill_model_name = prefill_service.get_model_name()
if require_decode:
decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None:
raise ModelNameNotFoundError()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
def run_sla_planner_dryrun(args: argparse.Namespace) -> None:
    """Replay a trace through the SLA planner(s) offline and plot the result.

    Metrics are extracted from ``args.dataset`` in windows of
    ``args.adjustment_interval``. Depending on ``args.mode`` ("disagg",
    "prefill" or "decode"; "disagg" when the attribute is missing) a prefill
    planner, a decode planner, or both (sharing one ``PlannerSharedState``)
    are created with ``dryrun=True``. For every window after the first one the
    load is predicted, replica counts are computed and clamped to the GPU
    budget, and the observed window is then fed back to the planners. All
    collected series are handed to ``create_dryrun_plot``.

    Raises:
        ValueError: if the dataset yields no metrics or the mode is unknown.
    """
    warmup_metrics = None
    warmup_trace = getattr(args, "load_predictor_warmup_trace", None)
    if warmup_trace:
        warmup_metrics = extract_metrics_from_mooncake(
            warmup_trace, args.adjustment_interval
        )

    metrics = extract_metrics_from_mooncake(args.dataset, args.adjustment_interval)
    if not metrics:
        raise ValueError("Empty metrics dataset: cannot run dryrun")

    # Build only the planner(s) the selected mode needs.
    mode = getattr(args, "mode", "disagg")
    prefill_planner: Optional[PrefillPlanner] = None
    decode_planner: Optional[DecodePlanner] = None
    if mode == "prefill":
        prefill_planner = PrefillPlanner(None, args, dryrun=True)
    elif mode == "decode":
        decode_planner = DecodePlanner(None, args, dryrun=True)
    elif mode == "disagg":
        # Both planners operate on the same shared state object.
        shared = PlannerSharedState()
        prefill_planner = PrefillPlanner(None, args, dryrun=True, shared_state=shared)
        decode_planner = DecodePlanner(None, args, dryrun=True, shared_state=shared)
    else:
        raise ValueError(f"Invalid planner mode: {mode}")

    def compute_safe_p_thpt(replicas: int, in_len: float, ttft_sla: float):
        """safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
        assert prefill_planner is not None
        interpolator = prefill_planner.prefill_interpolator
        if interpolator.interpolate_ttft(in_len) > ttft_sla:
            return 0
        return replicas * interpolator.interpolate_thpt_per_gpu(in_len)

    def compute_safe_d_thpt(
        replicas: int, in_len: float, out_len: float, itl_sla: float
    ):
        """safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
        assert decode_planner is not None
        best = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
            itl=itl_sla, context_length=in_len + out_len / 2
        )
        thpt_per_gpu, achieved_itl, _ = best
        if achieved_itl > itl_sla:
            return 0
        return replicas * thpt_per_gpu

    # Seed every series with the first window; its estimates equal the truth.
    first = metrics[0]
    time_series = [0]
    req_series = [first["request_count"]]
    est_req_series = [first["request_count"]]
    isl_series = [first["avg_isl"]]
    est_isl_series = [first["avg_isl"]]
    osl_series = [first["avg_osl"]]
    est_osl_series = [first["avg_osl"]]

    if prefill_planner is None:
        num_p, p_thpt, safe_p_thpt = [0], [0], [0]
    else:
        num_p = [args.start_num_p]
        p_thpt = [req_series[0] * isl_series[0]]
        safe_p_thpt = [
            compute_safe_p_thpt(args.start_num_p, isl_series[0], args.ttft)
            * args.adjustment_interval
        ]
        prefill_planner.dryrun_observe_metrics(
            req_series[0], isl_series[0], osl_series[0]
        )

    if decode_planner is None:
        num_d, d_thpt, safe_d_thpt = [0], [0], [0]
    else:
        num_d = [args.start_num_d]
        d_thpt = [req_series[0] * osl_series[0]]
        safe_d_thpt = [
            compute_safe_d_thpt(
                args.start_num_d, isl_series[0], osl_series[0], args.itl
            )
            * args.adjustment_interval
        ]
        decode_planner.dryrun_observe_metrics(
            req_series[0], isl_series[0], osl_series[0]
        )

    # Load prediction is taken from the prefill planner when there is one.
    predictor_planner = prefill_planner or decode_planner
    assert predictor_planner is not None

    for window in metrics[1:]:
        # update time
        time_series.append(time_series[-1] + args.adjustment_interval)

        # load prediction
        pred_rr, pred_isl, pred_osl = predictor_planner.predict_load()
        est_req_series.append(pred_rr)
        est_isl_series.append(pred_isl)
        est_osl_series.append(pred_osl)

        # compute num_p and num_d
        want_p = 0
        if prefill_planner is not None:
            want_p = prefill_planner._compute_replica_requirements(
                pred_rr, pred_isl, pred_osl
            )
        want_d = 0
        if decode_planner is not None:
            want_d = decode_planner._compute_replica_requirements(
                pred_rr, pred_isl, pred_osl
            )

        # apply GPU budget
        if prefill_planner is not None and decode_planner is not None:
            want_p, want_d = _apply_global_gpu_budget(want_p, want_d, args)
        elif prefill_planner is not None:
            want_p = _apply_component_gpu_budget(
                want_p, args.prefill_engine_num_gpu, args
            )
        elif decode_planner is not None:
            want_d = _apply_component_gpu_budget(
                want_d, args.decode_engine_num_gpu, args
            )
        num_p.append(want_p)
        num_d.append(want_d)

        # update load predictor (prefill first, then decode)
        if prefill_planner is not None:
            prefill_planner.dryrun_observe_metrics(
                window["request_count"], window["avg_isl"], window["avg_osl"]
            )
        if decode_planner is not None:
            decode_planner.dryrun_observe_metrics(
                window["request_count"], window["avg_isl"], window["avg_osl"]
            )

        # fill in ground truth
        req_series.append(window["request_count"])
        isl_series.append(window["avg_isl"])
        osl_series.append(window["avg_osl"])

        if prefill_planner is not None:
            p_thpt.append(req_series[-1] * isl_series[-1])
        else:
            p_thpt.append(0)
        if decode_planner is not None:
            d_thpt.append(req_series[-1] * osl_series[-1])
        else:
            d_thpt.append(0)

        if prefill_planner is not None:
            safe_p_thpt.append(
                compute_safe_p_thpt(num_p[-1], isl_series[-1], args.ttft)
                * args.adjustment_interval
            )
        else:
            safe_p_thpt.append(0)
        if decode_planner is not None:
            safe_d_thpt.append(
                compute_safe_d_thpt(
                    num_d[-1], isl_series[-1], osl_series[-1], args.itl
                )
                * args.adjustment_interval
            )
        else:
            safe_d_thpt.append(0)

    # Warmup windows are placed on negative timestamps, ending one interval
    # before time 0.
    warmup_time = warmup_rr = warmup_isl = warmup_osl = None
    if warmup_metrics:
        step = args.adjustment_interval
        count = len(warmup_metrics)
        warmup_time = [-(count - idx) * step for idx in range(count)]
        warmup_rr = [entry["request_count"] for entry in warmup_metrics]
        warmup_isl = [entry["avg_isl"] for entry in warmup_metrics]
        warmup_osl = [entry["avg_osl"] for entry in warmup_metrics]

    create_dryrun_plot(
        time=time_series,
        rr=req_series,
        est_rr=est_req_series,
        isl=isl_series,
        est_isl=est_isl_series,
        osl=osl_series,
        est_osl=est_osl_series,
        num_p=num_p,
        p_thpt=p_thpt,
        safe_p_thpt=safe_p_thpt,
        num_d=num_d,
        d_thpt=d_thpt,
        safe_d_thpt=safe_d_thpt,
        output_path=args.output_plot,
        warmup_time=warmup_time,
        warmup_rr=warmup_rr,
        warmup_isl=warmup_isl,
        warmup_osl=warmup_osl,
    )
......@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
choices=["vllm", "sglang", "trtllm", "mocker"],
help="Backend type",
)
parser.add_argument(
"--mode",
default=SLAPlannerDefaults.mode,
choices=["disagg", "prefill", "decode"],
help="Planner mode: disagg (prefill+decode), prefill-only, or decode-only",
)
parser.add_argument(
"--no-operation",
action="store_true",
......@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
"--max-gpu-budget",
type=int,
default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget",
help="Maximum GPU budget (-1 for no budget enforcement)",
)
parser.add_argument(
"--min-endpoint",
......
......@@ -6,7 +6,7 @@ import asyncio
import logging
import math
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from prometheus_client import Gauge, start_http_server
......@@ -119,15 +119,110 @@ class PlannerPrometheusMetrics:
self.gpu_hours = Gauge(f"{prefix}:gpu_hours", "Cumulative GPU hours used")
class Planner:
@dataclass
class PlannerSharedState:
    # State a planner keeps on ``self.shared_state``. One instance can be
    # passed to several planners so they read and write the same values; a
    # planner constructed without one creates its own.

    # Most recently observed metrics (exposed by the planner's
    # ``last_metrics`` property).
    last_metrics: Metrics = field(default_factory=Metrics)
    # Prefill / decode worker endpoints stored by ``observe_metrics``.
    p_endpoints: list = field(default_factory=list)
    d_endpoints: list = field(default_factory=list)
    # Running total of GPU hours, incremented once per observation.
    cumulative_gpu_hours: float = 0.0
    # ``time.time()`` value recorded when the last adjustment interval began.
    last_adjustment_time: float = 0.0
def _apply_global_gpu_budget(
    next_num_p: int, next_num_d: int, args: argparse.Namespace
) -> tuple[int, int]:
    """Apply GPU budget constraint to both prefill and decode replicas.

    When total GPUs required (num_p * prefill_gpus + num_d * decode_gpus)
    exceeds the budget, prefill is scaled by budget / total_required (floored)
    and clamped to [min_endpoint, max_prefill], where max_prefill reserves
    enough GPUs for min_endpoint decode replicas. The remaining budget is then
    allocated to decode.

    All arithmetic is done with integer floor division so the result is exact;
    the floating-point form ``floor(n * (budget / total))`` can land one
    replica too low (e.g. ``49 * (2 / 49)`` evaluates to 1.9999999999999998).

    Args:
        next_num_p: Requested number of prefill replicas.
        next_num_d: Requested number of decode replicas.
        args: Namespace providing ``max_gpu_budget`` (negative disables the
            budget), ``min_endpoint``, ``prefill_engine_num_gpu`` and
            ``decode_engine_num_gpu``; these are assumed to be integers.

    Returns:
        The (prefill, decode) replica counts after applying the budget, or
        (0, 0) if the budget cannot satisfy min_endpoint for both components.
    """
    budget = args.max_gpu_budget
    if budget < 0:
        return next_num_p, next_num_d

    prefill_gpus = args.prefill_engine_num_gpu
    decode_gpus = args.decode_engine_num_gpu
    min_endpoint = args.min_endpoint

    total_gpu_required = next_num_p * prefill_gpus + next_num_d * decode_gpus
    if total_gpu_required <= budget:
        return next_num_p, next_num_d

    min_required = min_endpoint * prefill_gpus + min_endpoint * decode_gpus
    if budget < min_required:
        logger.warning(
            f"max_gpu_budget ({budget}) is below the minimum required "
            f"for min_endpoint ({min_required}); enforcing zero replicas"
        )
        return 0, 0

    # Exact floor(next_num_p * budget / total_gpu_required).
    scaled_prefill = (next_num_p * budget) // total_gpu_required
    # Largest prefill count that still leaves room for min_endpoint decodes.
    max_prefill = (budget - min_endpoint * decode_gpus) // prefill_gpus
    next_num_p = max(min_endpoint, min(max_prefill, scaled_prefill))

    # NOTE(review): decode receives whatever budget is left, which can exceed
    # the originally requested decode count -- confirm this is intended.
    remaining = budget - next_num_p * prefill_gpus
    next_num_d = max(min_endpoint, remaining // decode_gpus)

    logger.warning(
        f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({budget}), "
        f"scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
    )
    return next_num_p, next_num_d
def _apply_component_gpu_budget(
    desired_replicas: int, engine_num_gpu: int, args: argparse.Namespace
) -> int:
    """Apply GPU budget constraint to a single component (prefill-only or decode-only).

    When total GPUs required (replicas * gpus_per_replica) exceeds the budget,
    the replica count is scaled by budget / total_required, floored, and
    clamped to at least min_endpoint. The scaling uses integer floor division
    so the result is exact; ``floor(n * (budget / total))`` in floating point
    can land one replica too low (e.g. ``49 * (2 / 49)`` evaluates to
    1.9999999999999998).

    Args:
        desired_replicas: Requested number of replicas.
        engine_num_gpu: GPUs used by one replica.
        args: Namespace providing ``max_gpu_budget`` (negative disables the
            budget) and ``min_endpoint``; assumed to be integers.

    Returns:
        The replica count after applying the budget, or 0 if the budget cannot
        satisfy min_endpoint replicas.
    """
    budget = args.max_gpu_budget
    if budget < 0:
        return desired_replicas

    total_gpu_required = desired_replicas * engine_num_gpu
    if total_gpu_required <= budget:
        return desired_replicas

    min_required = args.min_endpoint * engine_num_gpu
    if budget < min_required:
        logger.warning(
            f"max_gpu_budget ({budget}) is below the minimum required "
            f"for min_endpoint ({min_required}); enforcing zero replicas"
        )
        return 0

    # Exact floor(desired_replicas * budget / total_gpu_required).
    scaled = (desired_replicas * budget) // total_gpu_required
    next_num = max(args.min_endpoint, scaled)
    logger.warning(
        f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({budget}), "
        f"scaling down to {next_num} replicas"
    )
    return next_num
class BasePlanner:
component_type: SubComponentType
def __init__(
self,
runtime: Optional[DistributedRuntime],
args: argparse.Namespace,
dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_api_client: Optional[PrometheusAPIClient] = None,
connector=None,
start_prometheus_server: bool = True,
):
self.args = args
self.dryrun = dryrun
self.shared_state = shared_state or PlannerSharedState()
# Rely on getting model name from connector
self.model_name: Optional[str] = None
......@@ -137,7 +232,9 @@ class Planner:
self.namespace = args.namespace
if not args.no_operation:
if args.environment == "kubernetes":
if connector is not None:
self.connector = connector
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(
self.namespace, self.model_name
)
......@@ -150,7 +247,7 @@ class Planner:
else:
raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient(
self.prometheus_api_client = prometheus_api_client or PrometheusAPIClient(
args.metric_pulling_prometheus_endpoint,
args.namespace,
)
......@@ -231,22 +328,16 @@ class Planner:
if not self.dryrun:
self.prefill_client = None
self.workers_client = None
self.p_endpoints = [] # type: ignore
self.d_endpoints = [] # type: ignore
self.last_adjustment_time = time.time()
self.last_metrics = Metrics()
self.prometheus_port = args.metric_reporting_prometheus_port
# Initialize Prometheus metrics
self.prometheus_metrics = PlannerPrometheusMetrics()
# Track cumulative GPU hours
self.cumulative_gpu_hours = 0.0
if prometheus_metrics is None:
self.prometheus_metrics = PlannerPrometheusMetrics()
else:
self.prometheus_metrics = prometheus_metrics
# Start Prometheus HTTP server if port is specified
if self.prometheus_port != 0:
if start_prometheus_server and self.prometheus_port != 0:
try:
start_http_server(self.prometheus_port)
logger.info(
......@@ -254,6 +345,9 @@ class Planner:
)
except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}")
else:
self.prometheus_port = 0
self.prometheus_metrics = prometheus_metrics
self.p_correction_factor = 1.0
self.d_correction_factor = 1.0
......@@ -262,6 +356,14 @@ class Planner:
else:
self.no_correction = args.no_correction
@property
def last_metrics(self) -> Metrics:
    """Return the metrics object stored on ``self.shared_state``."""
    return self.shared_state.last_metrics

@last_metrics.setter
def last_metrics(self, value: Metrics) -> None:
    """Store ``value`` as the metrics object on ``self.shared_state``."""
    self.shared_state.last_metrics = value
async def _async_init(self):
"""Async initialization for components that need it"""
if (
......@@ -271,80 +373,94 @@ class Planner:
):
await self.connector._async_init()
async def get_workers_info(self):
async def _get_model_name(self, require_prefill: bool, require_decode: bool) -> str:
    """Ask the connector for the model name, awaiting it when necessary.

    The connector's ``get_model_name`` may return either a plain value or a
    coroutine; a coroutine is awaited before the result is returned.
    """
    result = self.connector.get_model_name(
        require_prefill=require_prefill, require_decode=require_decode
    )
    if not asyncio.iscoroutine(result):
        return result
    return await result
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
    """Create a client for the given component and endpoint, with a brief sleep for state sync."""
    namespace = self.runtime.namespace(self.namespace)
    endpoint = namespace.component(component_name).endpoint(endpoint_name)
    client = await endpoint.client()
    # TODO: remove this sleep after rust client() is blocking until watching state
    await asyncio.sleep(0.1)
    return client
async def get_workers_info(
self, require_prefill: bool = True, require_decode: bool = True
):
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
try:
if self.prefill_client is None:
self.prefill_client = (
await self.runtime.namespace(self.namespace)
.component(
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_component_name
)
.endpoint(
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_endpoint
p_endpoints = []
d_endpoints = []
worker_names = WORKER_COMPONENT_NAMES[self.args.backend]
if require_prefill:
try:
if self.prefill_client is None:
self.prefill_client = await self._get_or_create_client(
worker_names.prefill_worker_component_name,
worker_names.prefill_worker_endpoint,
)
.client()
p_endpoints = self.prefill_client.instance_ids() # type: ignore
except Exception:
p_endpoints = []
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling instance_ids
p_endpoints = self.prefill_client.instance_ids() # type: ignore
except Exception:
p_endpoints = []
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
)
try:
if self.workers_client is None:
self.workers_client = (
await self.runtime.namespace(self.namespace)
.component(
WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_component_name
)
.endpoint(
WORKER_COMPONENT_NAMES[self.args.backend].decode_worker_endpoint
if require_decode:
try:
if self.workers_client is None:
self.workers_client = await self._get_or_create_client(
worker_names.decode_worker_component_name,
worker_names.decode_worker_endpoint,
)
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling instance_ids
d_endpoints = self.workers_client.instance_ids() # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
d_endpoints = self.workers_client.instance_ids() # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints
async def observe_metrics(self):
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
async def observe_metrics(
self, require_prefill: bool = True, require_decode: bool = True
):
p_endpoints, d_endpoints = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.p_endpoints = p_endpoints
self.shared_state.d_endpoints = d_endpoints
logger.debug(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
f"Number of prefill workers: {len(p_endpoints)}, number of decode workers: {len(d_endpoints)}"
)
# Update Prometheus metrics if server is running
if self.prometheus_port != 0:
self.prometheus_metrics.num_p_workers.set(len(self.p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(self.d_endpoints))
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.num_p_workers.set(len(p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(d_endpoints))
# Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours = (
(
len(self.p_endpoints) * self.args.prefill_engine_num_gpu
+ len(self.d_endpoints) * self.args.decode_engine_num_gpu
len(p_endpoints) * self.args.prefill_engine_num_gpu
+ len(d_endpoints) * self.args.decode_engine_num_gpu
)
* self.args.adjustment_interval
/ 3600
)
self.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(self.cumulative_gpu_hours)
self.shared_state.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(
self.shared_state.cumulative_gpu_hours
)
# Prometheus returns seconds, convert to milliseconds
self.last_metrics.ttft = (
......@@ -392,7 +508,7 @@ class Planner:
)
# Update observed metrics in Prometheus
if self.prometheus_port != 0:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft)
self.prometheus_metrics.observed_itl.set(self.last_metrics.itl)
self.prometheus_metrics.observed_request_rate.set(
......@@ -404,9 +520,12 @@ class Planner:
self.prometheus_metrics.observed_isl.set(self.last_metrics.isl)
self.prometheus_metrics.observed_osl.set(self.last_metrics.osl)
self.num_req_predictor.add_data_point(self.last_metrics.num_req)
self.isl_predictor.add_data_point(self.last_metrics.isl)
self.osl_predictor.add_data_point(self.last_metrics.osl)
self.update_predictors_from_metrics(self.last_metrics)
def update_predictors_from_metrics(self, metrics: Metrics) -> None:
    """Feed one observation into each load predictor.

    The request count, ISL and OSL of ``metrics`` are added, in that order, to
    ``num_req_predictor``, ``isl_predictor`` and ``osl_predictor``.
    """
    predictor_to_field = (
        ("num_req_predictor", "num_req"),
        ("isl_predictor", "isl"),
        ("osl_predictor", "osl"),
    )
    for predictor_attr, metric_attr in predictor_to_field:
        getattr(self, predictor_attr).add_data_point(getattr(metrics, metric_attr))
def predict_load(self):
try:
......@@ -427,43 +546,193 @@ class Planner:
self.isl_predictor.add_data_point(isl_avg)
self.osl_predictor.add_data_point(osl_avg)
def plan_adjustment(self) -> Optional[int]:
    """Compute the desired replica count for the next interval.

    Returns:
        The value of ``_compute_replica_requirements`` for the predicted load,
        or ``None`` when the adjustment is skipped: the last metrics are
        invalid, the correction-factor update fails or returns a falsy value,
        the load prediction contains ``None``, or the replica computation
        raises.
    """
    # Skip adjustment if no traffic
    if not self.last_metrics.is_valid():
        logger.info(
            "Metrics contain None or NaN values (no active requests), skipping adjustment"
        )
        return None

    if not self.no_correction:
        try:
            factor_updated = self._update_correction_factor()
        except Exception as e:
            logger.error(f"Failed to correct prediction factors: {e}")
            return None
        if not factor_updated:
            return None

    next_num_req, next_isl, next_osl = self.predict_load()
    if any(value is None for value in (next_num_req, next_isl, next_osl)):
        return None

    # Update predicted load metrics in Prometheus
    if self.prometheus_port != 0 and self.prometheus_metrics is not None:
        gauges = self.prometheus_metrics
        gauges.predicted_request_rate.set(
            next_num_req / self.args.adjustment_interval
        )
        gauges.predicted_isl.set(next_isl)
        gauges.predicted_osl.set(next_osl)

    try:
        replicas = self._compute_replica_requirements(
            next_num_req, next_isl, next_osl
        )
    except Exception as e:
        logger.error(f"Failed to compute number of replicas: {e}")
        return None
    return replicas
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
    """Publish the planned replica count; must be overridden by subclasses."""
    raise NotImplementedError
def _compute_replica_requirements(
    self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
    """Return the replica count needed for the predicted load.

    Must be overridden by subclasses; the base implementation always raises.
    """
    raise NotImplementedError
def _update_correction_factor(self) -> bool:
    """Refresh the correction factor; must be overridden by subclasses.

    ``plan_adjustment`` skips the adjustment when this returns a falsy value.
    """
    raise NotImplementedError
def _component_name(self) -> str:
    """Return the component name that matches this planner's component type."""
    is_prefill = self.component_type == SubComponentType.PREFILL
    return self.prefill_component_name if is_prefill else self.decode_component_name
def _engine_num_gpu(self) -> int:
    """Return the per-engine GPU count that matches this planner's component type."""
    is_prefill = self.component_type == SubComponentType.PREFILL
    return (
        self.args.prefill_engine_num_gpu
        if is_prefill
        else self.args.decode_engine_num_gpu
    )
def apply_component_budget(self, desired_replicas: int) -> int:
    """Clamp ``desired_replicas`` to the GPU budget for this planner's component."""
    gpus_per_replica = self._engine_num_gpu()
    return _apply_component_gpu_budget(desired_replicas, gpus_per_replica, self.args)
async def _apply_scaling(self, desired_replicas: int) -> None:
    """Ask the connector to scale this planner's component.

    Does nothing when ``args.no_operation`` is set. Otherwise a single
    ``TargetReplica`` is sent with ``blocking=False``.
    """
    if self.args.no_operation:
        return
    target = TargetReplica(
        sub_component_type=self.component_type,
        component_name=self._component_name(),
        desired_replicas=desired_replicas,
    )
    await self.connector.set_component_replicas([target], blocking=False)
async def run(self):
    """Main loop for the planner.

    Unless ``args.no_operation`` is set, the deployment is validated for this
    planner's component type, awaited until ready, and the model name is read
    from the connector (stored lowercased). The loop then runs forever: once
    per ``args.adjustment_interval`` it observes metrics, plans an adjustment,
    applies the component GPU budget, updates the predicted-replicas metric
    and applies the scaling.
    """
    # Computed before the no_operation guard: the loop below needs these flags
    # even when deployment validation is skipped.
    require_prefill = self.component_type == SubComponentType.PREFILL
    require_decode = self.component_type == SubComponentType.DECODE

    if not self.args.no_operation:
        logger.info("Validating deployment...")
        await self.connector.validate_deployment(
            prefill_component_name=(
                self.prefill_component_name if require_prefill else None
            ),
            decode_component_name=(
                self.decode_component_name if require_decode else None
            ),
            require_prefill=require_prefill,
            require_decode=require_decode,
        )
        logger.info("Successfully validated the deployment")

        await self.connector.wait_for_deployment_ready()

        model_name = await self._get_model_name(
            require_prefill=require_prefill, require_decode=require_decode
        )
        logger.info(f"Detected model name from deployment: {model_name}")
        self.model_name = (
            model_name.lower()
        )  # normalize model name to lowercase (MDC)

    self.shared_state.last_adjustment_time = time.time()

    while True:
        current_time = time.time()

        if (
            current_time - self.shared_state.last_adjustment_time
            >= self.args.adjustment_interval
        ):
            self.shared_state.last_adjustment_time = time.time()
            logger.info("New adjustment interval started!")

            await self.observe_metrics(
                require_prefill=require_prefill, require_decode=require_decode
            )
            desired_replicas = self.plan_adjustment()
            if desired_replicas is not None:
                desired_replicas = self.apply_component_budget(desired_replicas)
                self.update_predicted_replicas_metric(desired_replicas)
                await self._apply_scaling(desired_replicas)

        # sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
        await asyncio.sleep(self.args.adjustment_interval / 10)
class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL
def _update_correction_factor(self) -> bool:
    """Recompute the prefill correction factor from the last observed TTFT.

    The factor is observed TTFT divided by the TTFT the prefill interpolator
    gives for the observed ISL. Always returns True.
    """
    observed = self.last_metrics
    expected_ttft = self.prefill_interpolator.interpolate_ttft(observed.isl)
    self.p_correction_factor = observed.ttft / expected_ttft
    logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")
    if self.prometheus_port != 0 and self.prometheus_metrics is not None:
        gauge = self.prometheus_metrics.p_correction_factor
        gauge.set(self.p_correction_factor)
    return True
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> tuple[int, int]:
"""Compute the number of prefill and decode replicas needed based on predicted load.
Args:
next_num_req: Predicted number of requests
next_isl: Predicted input sequence length
next_osl: Predicted output sequence length
Returns:
tuple[int, int]: Number of prefill and decode replicas needed
"""
# compute how many replicas are needed for prefill
# here we assume the prefill bias is purely due to request queueing
# and we increase the number of prefill replicas linearly to account for the queueing delay
) -> int:
pred_prefill_throughput = (
next_num_req
* next_isl
/ self.args.adjustment_interval
* min(1, self.p_correction_factor)
)
p_thpt_per_gpu = self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
next_num_p = math.ceil(
pred_prefill_throughput
/ self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
/ self.args.prefill_engine_num_gpu
pred_prefill_throughput / p_thpt_per_gpu / self.args.prefill_engine_num_gpu
)
next_num_p = max(next_num_p, self.args.min_endpoint)
logger.info(
f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / "
f"{self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl) * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = "
f"{p_thpt_per_gpu * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = "
f"{next_num_p}(num_p)"
)
return next_num_p
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
    """Set the ``predicted_num_p`` gauge when Prometheus reporting is active."""
    if self.prometheus_port == 0 or self.prometheus_metrics is None:
        return
    self.prometheus_metrics.predicted_num_p.set(desired_replicas)
# compute how many replicas are needed for decode
# 1. apply d_correction_factor to the ITL SLA
# Prevent divide by zero when d_correction_factor is 0 (no metrics yet)
class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE
def _update_correction_factor(self) -> bool:
    """Recompute the decode correction factor from the last observed ITL.

    When the shared state lists no decode endpoints the factor is left
    untouched. Otherwise it becomes observed ITL divided by the ITL the decode
    interpolator gives for the observed per-worker concurrency and context
    length. Always returns True.
    """
    decode_workers = self.shared_state.d_endpoints
    if not decode_workers:
        logger.warning(
            "No decode workers found for correction factor, skipping correction update"
        )
        return True

    observed = self.last_metrics
    requests_per_worker = observed.num_req / len(decode_workers)  # type: ignore
    concurrency = (
        requests_per_worker
        * observed.request_duration  # type: ignore
        / self.args.adjustment_interval
    )
    context_length = observed.isl + observed.osl / 2  # type: ignore
    expect_itl = self.decode_interpolator.interpolate_itl(
        concurrency=concurrency,
        context_length=context_length,
    )
    self.d_correction_factor = observed.itl / expect_itl
    logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
    if self.prometheus_port != 0 and self.prometheus_metrics is not None:
        gauge = self.prometheus_metrics.d_correction_factor
        gauge.set(self.d_correction_factor)
    return True
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
if self.d_correction_factor <= 0:
logger.warning(
f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0"
......@@ -471,7 +740,6 @@ class Planner:
corrected_itl = self.args.itl
else:
corrected_itl = self.args.itl / self.d_correction_factor
# 2. reversely find out what is best throughput/gpu that can achieve corrected_itl under the predicted context length
(
pred_decode_thpt_per_gpu,
_,
......@@ -479,318 +747,136 @@ class Planner:
) = self.decode_interpolator.find_best_throughput_per_gpu(
itl=corrected_itl, context_length=next_isl + next_osl / 2
)
# 3. compute number of decode replicas needed
pred_decode_throughput = next_num_req * next_osl / self.args.adjustment_interval
next_num_d = math.ceil(
pred_decode_throughput
/ pred_decode_thpt_per_gpu
/ self.args.decode_engine_num_gpu
)
next_num_d = max(next_num_d, self.args.min_endpoint)
logger.info(
f"Decode calculation: {pred_decode_throughput:.2f}(d_thpt) / "
f"{pred_decode_thpt_per_gpu * self.args.decode_engine_num_gpu:.2f}(d_engine_cap) = "
f"{next_num_d}(num_d)"
)
return next_num_d
# correct num_p and num_d based on the gpu budget
next_num_p = max(next_num_p, self.args.min_endpoint)
next_num_d = max(next_num_d, self.args.min_endpoint)
logger.info(
f"Predicted number of engine replicas: prefill={next_num_p}, decode={next_num_d}"
)
total_gpu_required = (
next_num_p * self.args.prefill_engine_num_gpu
+ next_num_d * self.args.decode_engine_num_gpu
)
if total_gpu_required > self.args.max_gpu_budget:
scale = self.args.max_gpu_budget / total_gpu_required
next_num_p = max(self.args.min_endpoint, round(next_num_p * scale))
next_num_d = max(
self.args.min_endpoint,
round(
(
self.args.max_gpu_budget
- next_num_p * self.args.prefill_engine_num_gpu
)
/ self.args.decode_engine_num_gpu
),
)
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({self.args.max_gpu_budget}), scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
)
return next_num_p, next_num_d
async def make_adjustments(self):
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
if not self.no_correction:
try:
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
    """Record ``desired_replicas`` on the ``predicted_num_d`` metric.

    Nothing is recorded when ``prometheus_port`` is 0 or when no
    ``prometheus_metrics`` object is attached to this planner.
    """
    if self.prometheus_port == 0 or self.prometheus_metrics is None:
        return
    self.prometheus_metrics.predicted_num_d.set(desired_replicas)
# first correct the prediction correction factor
# for TTFT, we expect the correction factor to be << 1 due to queuing delay
expect_ttft = self.prefill_interpolator.interpolate_ttft(
self.last_metrics.isl
)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for ITL, we expect the correction factor to be close to 1
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
# Update correction factor metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.p_correction_factor.set(
self.p_correction_factor
)
self.prometheus_metrics.d_correction_factor.set(
self.d_correction_factor
)
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is not None and next_isl is not None and next_osl is not None:
# Update predicted load metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_request_rate.set(
next_num_req / self.args.adjustment_interval
)
self.prometheus_metrics.predicted_isl.set(next_isl)
self.prometheus_metrics.predicted_osl.set(next_osl)
try:
next_num_p, next_num_d = self._compute_replica_requirements(
next_num_req, next_isl, next_osl
)
# Update predicted replica metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_p.set(next_num_p)
self.prometheus_metrics.predicted_num_d.set(next_num_d)
except Exception as e:
logger.error(f"Failed to compute number of replicas: {e}")
return
class DisaggPlanner:
def __init__(
    self, runtime: Optional[DistributedRuntime], args: argparse.Namespace
) -> None:
    """Create a prefill planner and a decode planner that share state.

    Both sub-planners are given the same ``PlannerSharedState`` and the same
    ``PlannerPrometheusMetrics`` object.  Only the prefill planner is asked to
    start the Prometheus server.  The decode planner receives the prefill
    planner's ``prometheus_api_client`` and ``connector`` attributes, or
    ``None`` for whichever of them the prefill planner does not have.
    """
    self.args = args
    self.shared_state = PlannerSharedState()

    # Keyword arguments that are identical for both sub-planners.
    common = {
        "shared_state": self.shared_state,
        "prometheus_metrics": PlannerPrometheusMetrics(),
    }

    self.prefill_planner = PrefillPlanner(
        runtime, args, start_prometheus_server=True, **common
    )
    prefill = self.prefill_planner
    self.decode_planner = DecodePlanner(
        runtime,
        args,
        prometheus_api_client=getattr(prefill, "prometheus_api_client", None),
        connector=getattr(prefill, "connector", None),
        start_prometheus_server=False,
        **common,
    )
if not self.args.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_component_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_component_name,
desired_replicas=next_num_d,
),
]
await self.connector.set_component_replicas(target_replicas, blocking=False)
async def _async_init(self):
# Prefill/Decode share the same connector instance in disagg mode.
await self.prefill_planner._async_init()
async def run(self):
"""Main loop for the planner"""
if not self.args.no_operation:
# Fail fast if the deployment is not valid
logger.info("Validating deployment...")
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
await self.connector.validate_deployment(
prefill_component_name=self.prefill_component_name,
decode_component_name=self.decode_component_name,
await self.prefill_planner.connector.validate_deployment(
prefill_component_name=self.prefill_planner.prefill_component_name,
decode_component_name=self.prefill_planner.decode_component_name,
require_prefill=True,
require_decode=True,
)
logger.info("Successfully validated the deployment")
await self.connector.wait_for_deployment_ready()
await self.prefill_planner.connector.wait_for_deployment_ready()
model_name = self.connector.get_model_name()
model_name = await self.prefill_planner._get_model_name(
require_prefill=True, require_decode=True
)
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
model_name = model_name.lower()
self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name
self.last_adjustment_time = time.time()
self.shared_state.last_adjustment_time = time.time()
while True:
current_time = time.time()
if (
current_time - self.last_adjustment_time
current_time - self.shared_state.last_adjustment_time
>= self.args.adjustment_interval
):
self.last_adjustment_time = time.time()
self.shared_state.last_adjustment_time = time.time()
logger.info("New adjustment interval started!")
await self.observe_metrics()
await self.make_adjustments()
await self.prefill_planner.observe_metrics(
require_prefill=True, require_decode=True
)
self.decode_planner.update_predictors_from_metrics(
self.shared_state.last_metrics
)
next_num_p = self.prefill_planner.plan_adjustment()
next_num_d = self.decode_planner.plan_adjustment()
if next_num_p is None or next_num_d is None:
continue
next_num_p, next_num_d = _apply_global_gpu_budget(
next_num_p, next_num_d, self.args
)
self.prefill_planner.update_predicted_replicas_metric(next_num_p)
self.decode_planner.update_predicted_replicas_metric(next_num_d)
if not self.args.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name,
desired_replicas=next_num_d,
),
]
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=False
)
# sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
await asyncio.sleep(self.args.adjustment_interval / 10)
def dryrun_run(self):
    """Replay a recorded trace through the planner and plot the outcome.

    Per-interval metrics are extracted from ``args.dataset`` and, when
    ``args.load_predictor_warmup_trace`` is set, from that trace as well.
    The first interval seeds every series, using ``args.start_num_p`` and
    ``args.start_num_d`` as the replica counts.  For each later interval the
    load is predicted, replica requirements are computed from the prediction,
    and only then is the observed interval fed back through
    ``dryrun_observe_metrics``.  All collected series are handed to
    ``create_dryrun_plot`` with ``args.output_plot`` as the output path.
    """
    args = self.args
    interval = args.adjustment_interval

    warmup_trace = getattr(args, "load_predictor_warmup_trace", None)
    warmup_metrics = (
        extract_metrics_from_mooncake(warmup_trace, interval)
        if warmup_trace
        else None
    )
    metrics = extract_metrics_from_mooncake(args.dataset, interval)

    def safe_prefill_thpt(replicas: int, in_len: float, ttft_sla: float):
        """Prefill throughput usable within the TTFT SLA (0 if it is exceeded)."""
        if self.prefill_interpolator.interpolate_ttft(in_len) > ttft_sla:
            return 0
        return replicas * self.prefill_interpolator.interpolate_thpt_per_gpu(in_len)

    def safe_decode_thpt(replicas: int, in_len: float, out_len: float, itl_sla: float):
        """Decode throughput usable within the ITL SLA (0 if it is exceeded)."""
        (
            thpt_per_gpu,
            achieved_itl,
            _,
        ) = self.decode_interpolator.find_best_throughput_per_gpu(
            itl=itl_sla, context_length=in_len + out_len / 2
        )
        return 0 if achieved_itl > itl_sla else replicas * thpt_per_gpu

    # Seed every series from the first observed interval.
    first = metrics[0]
    first_rr = first["request_count"]
    first_isl = first["avg_isl"]
    first_osl = first["avg_osl"]

    elapsed = [0]
    rr = [first_rr]
    est_rr = [first_rr]
    isl = [first_isl]
    est_isl = [first_isl]
    osl = [first_osl]
    est_osl = [first_osl]

    num_p = [args.start_num_p]
    p_thpt = [first_rr * first_isl]
    safe_p_thpt = [
        safe_prefill_thpt(args.start_num_p, first_isl, args.ttft) * interval
    ]

    num_d = [args.start_num_d]
    d_thpt = [first_rr * first_osl]
    safe_d_thpt = [
        safe_decode_thpt(args.start_num_d, first_isl, first_osl, args.itl)
        * interval
    ]

    self.dryrun_observe_metrics(first_rr, first_isl, first_osl)

    for observed in metrics[1:]:
        elapsed.append(elapsed[-1] + interval)

        # Predict before the interval's real numbers are shown to the predictor.
        pred_rr, pred_isl, pred_osl = self.predict_load()
        est_rr.append(pred_rr)
        est_isl.append(pred_isl)
        est_osl.append(pred_osl)

        planned_p, planned_d = self._compute_replica_requirements(
            pred_rr, pred_isl, pred_osl
        )
        num_p.append(planned_p)
        num_d.append(planned_d)

        self.dryrun_observe_metrics(
            observed["request_count"], observed["avg_isl"], observed["avg_osl"]
        )

        # Ground truth for this interval.
        rr.append(observed["request_count"])
        isl.append(observed["avg_isl"])
        osl.append(observed["avg_osl"])
        p_thpt.append(rr[-1] * isl[-1])
        d_thpt.append(rr[-1] * osl[-1])
        safe_p_thpt.append(
            safe_prefill_thpt(num_p[-1], isl[-1], args.ttft) * interval
        )
        safe_d_thpt.append(
            safe_decode_thpt(num_d[-1], isl[-1], osl[-1], args.itl) * interval
        )

    # Imported here, as in the original, rather than at module import time.
    from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot

    warmup_time = warmup_rr = warmup_isl = warmup_osl = None
    if warmup_metrics:
        count = len(warmup_metrics)
        # Warmup samples are placed at negative times, ending one interval before 0.
        warmup_time = [-(count - idx) * interval for idx in range(count)]
        warmup_rr = [entry["request_count"] for entry in warmup_metrics]
        warmup_isl = [entry["avg_isl"] for entry in warmup_metrics]
        warmup_osl = [entry["avg_osl"] for entry in warmup_metrics]

    create_dryrun_plot(
        time=elapsed,
        rr=rr,
        est_rr=est_rr,
        isl=isl,
        est_isl=est_isl,
        osl=osl,
        est_osl=est_osl,
        num_p=num_p,
        p_thpt=p_thpt,
        safe_p_thpt=safe_p_thpt,
        num_d=num_d,
        d_thpt=d_thpt,
        safe_d_thpt=safe_d_thpt,
        output_path=args.output_plot,
        warmup_time=warmup_time,
        warmup_rr=warmup_rr,
        warmup_isl=warmup_isl,
        warmup_osl=warmup_osl,
    )
async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace):
    """Build the planner selected by ``args.mode`` and run it.

    ``args.mode`` is read with a fallback of ``"disagg"`` when the attribute
    is missing.  ``"disagg"`` builds a ``DisaggPlanner``, ``"prefill"`` a
    ``PrefillPlanner`` and ``"decode"`` a ``DecodePlanner``.  The chosen
    planner's ``_async_init()`` is awaited before its ``run()``.

    Raises:
        ValueError: if ``args.mode`` is not one of the three modes above.
    """
    # The mode dispatch below is the only place a planner is constructed; the
    # legacy ``Planner(runtime, args)`` instance that used to be created first
    # was discarded immediately and has been removed.
    mode = getattr(args, "mode", "disagg")
    if mode == "disagg":
        planner = DisaggPlanner(runtime, args)
    elif mode == "prefill":
        planner = PrefillPlanner(runtime, args)
    elif mode == "decode":
        planner = DecodePlanner(runtime, args)
    else:
        raise ValueError(f"Invalid planner mode: {mode}")
    await planner._async_init()
    await planner.run()
......@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""Validate the deployment"""
pass
......@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment to be ready"""
await self._wait_for_scaling_completion()
async def get_model_name(
    self, require_prefill: bool = True, require_decode: bool = True
) -> str:
    """Get the model name from the deployment.

    Returns ``self.model_name`` unchanged.  ``require_prefill`` and
    ``require_decode`` are accepted but not used by this connector.
    """
    # Explicitly discard the unused arguments.
    del require_prefill, require_decode
    return self.model_name
......@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime
sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402
from dynamo.planner.utils.planner_core import ( # noqa: E402
DecodePlanner,
Metrics,
PlannerSharedState,
PrefillPlanner,
_apply_global_gpu_budget,
)
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
class PlannerHarness:
    """Test helper that presents a prefill/decode planner pair as one object.

    Attribute reads and writes on the harness are routed to the wrapped
    planners or to the shared state according to the name tables below, and
    ``make_adjustments`` records the replica targets it computed in
    ``last_target_replicas`` so tests can inspect them.
    """

    # Attributes that live on the harness itself.
    _OWN_ATTRS = frozenset({"prefill_planner", "decode_planner", "shared_state"})

    # Names resolved on the prefill planner when read through the harness.
    _READ_FROM_PREFILL = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_api_client",
            "args",
            "get_workers_info",
            "prefill_interpolator",
            "prefill_component_name",
            "p_correction_factor",
        }
    )
    # Names resolved on the decode planner when read through the harness.
    _READ_FROM_DECODE = frozenset(
        {
            "decode_interpolator",
            "decode_component_name",
            "d_correction_factor",
        }
    )

    # Names whose assignment is mirrored onto both planners.
    _WRITE_TO_BOTH = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_api_client",
            "args",
            "get_workers_info",
        }
    )
    # Names whose assignment goes to exactly one planner.
    _WRITE_TO_PREFILL = frozenset({"prefill_interpolator", "p_correction_factor"})
    _WRITE_TO_DECODE = frozenset({"decode_interpolator", "d_correction_factor"})

    def __init__(self, prefill_planner, decode_planner, shared_state):
        self.prefill_planner = prefill_planner
        self.decode_planner = decode_planner
        self.shared_state = shared_state
        self.last_target_replicas = []

    async def make_adjustments(self):
        """Plan prefill and decode replicas once and record the targets.

        ``last_target_replicas`` is cleared first, so after an early return
        (invalid metrics, or either planner returning ``None``) it is empty
        rather than still holding the targets of an earlier call.  The targets
        are sent to the connector only when ``args.no_operation`` is false.
        """
        self.last_target_replicas = []

        if not self.shared_state.last_metrics.is_valid():
            return

        p_endpoints, d_endpoints = await self.prefill_planner.get_workers_info()
        self.shared_state.p_endpoints = p_endpoints
        self.shared_state.d_endpoints = d_endpoints

        next_num_p = self.prefill_planner.plan_adjustment()
        next_num_d = self.decode_planner.plan_adjustment()
        if next_num_p is None or next_num_d is None:
            return

        next_num_p, next_num_d = _apply_global_gpu_budget(
            next_num_p, next_num_d, self.prefill_planner.args
        )
        self.prefill_planner.update_predicted_replicas_metric(next_num_p)
        self.decode_planner.update_predicted_replicas_metric(next_num_d)

        # Both component names are read from the prefill planner.
        target_replicas = [
            {
                "sub_component_type": "prefill",
                "component_name": self.prefill_planner.prefill_component_name,
                "desired_replicas": next_num_p,
            },
            {
                "sub_component_type": "decode",
                "component_name": self.prefill_planner.decode_component_name,
                "desired_replicas": next_num_d,
            },
        ]
        self.last_target_replicas = target_replicas

        if not self.prefill_planner.args.no_operation:
            await self.prefill_planner.connector.set_component_replicas(
                target_replicas, blocking=False
            )

    def __getattr__(self, name):
        # Only reached when normal lookup on the harness fails.
        if name == "last_metrics":
            return self.shared_state.last_metrics
        if name in self._READ_FROM_PREFILL:
            return getattr(self.prefill_planner, name)
        if name in self._READ_FROM_DECODE:
            return getattr(self.decode_planner, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in self._OWN_ATTRS:
            super().__setattr__(name, value)
        elif name == "last_metrics":
            self.shared_state.last_metrics = value
        elif name in self._WRITE_TO_BOTH:
            # Store locally to support patch.object lifecycle (set/del).
            object.__setattr__(self, name, value)
            setattr(self.prefill_planner, name, value)
            setattr(self.decode_planner, name, value)
        elif name in self._WRITE_TO_PREFILL:
            setattr(self.prefill_planner, name, value)
        elif name in self._WRITE_TO_DECODE:
            setattr(self.decode_planner, name, value)
        else:
            super().__setattr__(name, value)
def _replica_count(target_replicas, component_name, default=1):
for replica in target_replicas:
if replica.get("component_name") == component_name:
return replica.get("desired_replicas", default)
return default
@pytest.fixture
def planner():
"""Set up test environment with mocked dependencies."""
......@@ -75,8 +196,10 @@ def planner():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
# Create planner instance
planner = Planner(mock_runtime, args)
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(mock_runtime, args, shared_state=shared_state)
decode_planner = DecodePlanner(mock_runtime, args, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Mock the interpolators to return fixed values for testing
planner.prefill_interpolator = Mock()
......@@ -165,19 +288,18 @@ class TestReplicaCalculation:
# Extract the calculated values from the log calls or by checking the mock calls
# Since we mocked the connector, we can check what replicas were requested
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_component = "VllmPrefillWorker"
calculated_prefill_replicas = call_args.get(prefill_component, 1)
print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
prefill_component = "VllmPrefillWorker"
calculated_prefill_replicas = _replica_count(
planner.last_target_replicas, prefill_component
)
print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
# Allow for small differences due to min_endpoint constraints
assert (
max(expected_prefill_replicas, planner.args.min_endpoint)
== calculated_prefill_replicas
)
# Allow for small differences due to min_endpoint constraints
assert (
max(expected_prefill_replicas, planner.args.min_endpoint)
== calculated_prefill_replicas
)
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -230,19 +352,18 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Check the results
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
decode_component = "VllmDecodeWorker"
calculated_decode_replicas = call_args.get(decode_component, 1)
print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}")
decode_component = "VllmDecodeWorker"
calculated_decode_replicas = _replica_count(
planner.last_target_replicas, decode_component
)
print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}")
# Allow for small differences due to min_endpoint constraints
assert (
max(expected_decode_replicas, planner.args.min_endpoint)
== calculated_decode_replicas
)
# Allow for small differences due to min_endpoint constraints
assert (
max(expected_decode_replicas, planner.args.min_endpoint)
== calculated_decode_replicas
)
@pytest.mark.parametrize(
"num_req,decode_thpt,expected_p,expected_d",
......@@ -304,20 +425,20 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify results
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
assert (
prefill_replicas == expected_p
), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}"
assert (
decode_replicas == expected_d
), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}"
assert (
prefill_replicas == expected_p
), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}"
assert (
decode_replicas == expected_d
), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}"
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -359,24 +480,24 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify that total GPU usage doesn't exceed budget
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu
)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu
)
print(
f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
)
print(
f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
)
assert (
total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage exceeds budget"
assert (
total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage exceeds budget"
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -417,20 +538,20 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify minimum constraints are respected
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
assert (
prefill_replicas >= planner.args.min_endpoint
), "Prefill replicas below minimum"
assert (
decode_replicas >= planner.args.min_endpoint
), "Decode replicas below minimum"
assert (
prefill_replicas >= planner.args.min_endpoint
), "Prefill replicas below minimum"
assert (
decode_replicas >= planner.args.min_endpoint
), "Decode replicas below minimum"
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -482,17 +603,17 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify that correction factor was effectively clamped
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
print(
f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}"
)
print(
f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}"
)
assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint
), "Prefill correction factor should be clamped to 1"
assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint
), "Prefill correction factor should be clamped to 1"
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -501,62 +622,59 @@ class TestReplicaCalculation:
"""Test handling of d_correction_factor <= 0."""
# Test both 0 and negative values
for correction_factor in [0.0, -1.0]:
with patch.object(planner, "connector") as mock_connector:
planner.p_correction_factor = 1.0
planner.d_correction_factor = correction_factor
# Mock predictor outputs
planner.num_req_predictor.predict_next.return_value = 10
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
# Mock interpolator outputs
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = (
40000
)
planner.decode_interpolator.find_best_throughput_per_gpu.return_value = (
10000,
0.01,
0.5,
)
# Set up metrics
planner.last_metrics = Metrics(
num_req=10,
isl=3000,
osl=150,
ttft=80.0,
itl=10.0,
request_duration=100.0,
)
# Mock workers info
async def mock_get_workers_info():
return (["prefill1"], ["decode1"])
planner.get_workers_info = mock_get_workers_info
# Mock interpolation calls
planner.prefill_interpolator.interpolate_ttft.return_value = 80.0
planner.decode_interpolator.interpolate_itl.return_value = 10.0
# Run calculation
asyncio.run(planner.make_adjustments())
# Should handle gracefully without crashing
# The code should use args.itl directly instead of dividing by 0
if mock_connector.set_component_replicas.called:
call_args = mock_connector.set_component_replicas.call_args[0][0]
decode_replicas = call_args.get("VllmDecodeWorker", 1)
print(
f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}"
)
# Should get a valid result (not crash)
assert (
decode_replicas >= 1
), f"Should handle correction factor {correction_factor} gracefully"
planner.p_correction_factor = 1.0
planner.d_correction_factor = correction_factor
# Mock predictor outputs
planner.num_req_predictor.predict_next.return_value = 10
planner.isl_predictor.predict_next.return_value = 3000
planner.osl_predictor.predict_next.return_value = 150
# Mock interpolator outputs
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = 40000
planner.decode_interpolator.find_best_throughput_per_gpu.return_value = (
10000,
0.01,
0.5,
)
# Set up metrics
planner.last_metrics = Metrics(
num_req=10,
isl=3000,
osl=150,
ttft=80.0,
itl=10.0,
request_duration=100.0,
)
# Mock workers info
async def mock_get_workers_info():
return (["prefill1"], ["decode1"])
planner.get_workers_info = mock_get_workers_info
# Mock interpolation calls
planner.prefill_interpolator.interpolate_ttft.return_value = 80.0
planner.decode_interpolator.interpolate_itl.return_value = 10.0
# Run calculation
asyncio.run(planner.make_adjustments())
# Should handle gracefully without crashing
# The code should use args.itl directly instead of dividing by 0
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(
f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}"
)
# Should get a valid result (not crash)
assert (
decode_replicas >= 1
), f"Should handle correction factor {correction_factor} gracefully"
@pytest.mark.nightly
@pytest.mark.gpu_2
......@@ -608,23 +726,23 @@ class TestReplicaCalculation:
# Run calculation
asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})"
)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})"
)
# Verify calculations account for multiple GPUs per engine
assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint
)
assert decode_replicas == max(
expected_decode_replicas, planner.args.min_endpoint
)
# Verify calculations account for multiple GPUs per engine
assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint
)
assert decode_replicas == max(
expected_decode_replicas, planner.args.min_endpoint
)
@pytest.mark.weekly
@pytest.mark.gpu_2
......@@ -668,31 +786,31 @@ class TestReplicaCalculation:
# Run calculation
asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
# Verify total GPU usage doesn't exceed budget
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu
)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
# Verify total GPU usage doesn't exceed budget
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu
)
print(
f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
)
print(
f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
)
assert (
total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage should not exceed budget"
assert (
prefill_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for prefill"
assert (
decode_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for decode"
assert (
total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage should not exceed budget"
assert (
prefill_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for prefill"
assert (
decode_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for decode"
# No need for unittest.main() with pytest!
......@@ -15,8 +15,8 @@
import logging
from dynamo.planner.utils.dryrun import run_sla_planner_dryrun
from dynamo.planner.utils.planner_argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__)
......@@ -45,5 +45,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
planner = Planner(None, args, dryrun=True)
planner.dryrun_run()
run_sla_planner_dryrun(args)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Patch ``planner_core.Gauge`` with a mock for the duration of each test."""
    gauge_patch = patch("dynamo.planner.utils.planner_core.Gauge")
    with gauge_patch as fake_gauge:
        # Every Gauge(...) construction hands back a plain Mock.
        fake_gauge.return_value = Mock()
        yield
def _build_args():
args = argparse.Namespace()
args.adjustment_interval = 60
args.prefill_engine_num_gpu = 1
args.decode_engine_num_gpu = 1
args.min_endpoint = 1
args.max_gpu_budget = -1
args.ttft = 500.0
args.itl = 50.0
args.backend = "vllm"
args.no_operation = True
args.no_correction = True
args.metric_pulling_prometheus_endpoint = "http://localhost:9090"
args.metric_reporting_prometheus_port = 0
args.load_predictor = "constant"
args.load_predictor_warmup_trace = None
args.profile_results_dir = os.path.join(
os.path.dirname(__file__),
"..",
"profiling_results",
"H200_TP1P_TP1D",
)
args.environment = "kubernetes"
args.namespace = "test-namespace"
args.mode = "disagg"
return args
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(args, prometheus_client):
    """Create a prefill and a decode planner wired for offline testing.

    Both planners are built with a ``None`` runtime and one shared
    ``PlannerSharedState``.  Afterwards each gets ``prometheus_client`` as its
    ``prometheus_api_client``, the model name ``"test-model"``, and a stub
    ``get_workers_info`` that reports one worker per required role.

    Returns:
        ``(prefill_planner, decode_planner, shared_state)``.
    """

    async def mock_get_workers_info(require_prefill=True, require_decode=True):
        prefill_workers = ["prefill-0"] if require_prefill else []
        decode_workers = ["decode-0"] if require_decode else []
        return (prefill_workers, decode_workers)

    shared_state = PlannerSharedState()
    prefill_planner = PrefillPlanner(None, args, shared_state=shared_state)
    decode_planner = DecodePlanner(None, args, shared_state=shared_state)

    for planner in (prefill_planner, decode_planner):
        planner.prometheus_api_client = prometheus_client
        planner.model_name = "test-model"
        planner.get_workers_info = mock_get_workers_info

    return prefill_planner, decode_planner, shared_state
def _expected_prefill(args, prefill_planner, sample):
pred_prefill_throughput = (
sample["num_req"] * sample["isl"] / args.adjustment_interval
)
thpt_per_gpu = prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
sample["isl"]
)
expected = math.ceil(
pred_prefill_throughput / thpt_per_gpu / args.prefill_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _expected_decode(args, decode_planner, sample):
(
pred_decode_thpt_per_gpu,
_,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=args.itl, context_length=sample["isl"] + sample["osl"] / 2
)
pred_decode_throughput = (
sample["num_req"] * sample["osl"] / args.adjustment_interval
)
expected = math.ceil(
pred_decode_throughput / pred_decode_thpt_per_gpu / args.decode_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
    """Run one observe-then-plan cycle for both planners.

    The prefill planner's ``observe_metrics`` coroutine is run to completion
    with both roles required. The decode planner's predictors are then
    updated from ``shared_state.last_metrics``, and finally each planner's
    ``plan_adjustment`` is called, prefill first.

    Returns:
        A ``(prefill_result, decode_result)`` tuple holding the two
        ``plan_adjustment`` return values.
    """
    observation = prefill_planner.observe_metrics(
        require_prefill=True, require_decode=True
    )
    asyncio.run(observation)

    # Only the prefill planner observes; the decode planner reads the metrics
    # back from the shared state.
    decode_planner.update_predictors_from_metrics(shared_state.last_metrics)

    return prefill_planner.plan_adjustment(), decode_planner.plan_adjustment()
def test_disagg_scale_up():
    """Both planners should ask for more replicas when load rises.

    Two intervals are replayed: 10 requests, then 5000 requests, with all
    other sample fields held constant. Each interval's result must match the
    independently computed expectation, and the second interval's counts
    must be strictly greater than the first's.
    """
    args = _build_args()
    constant_fields = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [
        {"num_req": num_req, **constant_fields} for num_req in (10, 5000)
    ]
    light_load, heavy_load = samples

    client = _build_prometheus_client(samples)
    prefill_planner, decode_planner, shared_state = _build_planners(args, client)

    # First interval replays the light sample, second the heavy one.
    low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
    high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)

    assert low_p == _expected_prefill(args, prefill_planner, light_load)
    assert low_d == _expected_decode(args, decode_planner, light_load)
    assert high_p == _expected_prefill(args, prefill_planner, heavy_load)
    assert high_d == _expected_decode(args, decode_planner, heavy_load)

    assert high_p > low_p
    assert high_d > low_d
def test_disagg_scale_down():
    """Both planners should ask for fewer replicas when load drops.

    Two intervals are replayed: 5000 requests, then 10 requests, with all
    other sample fields held constant. Each interval's result must match the
    independently computed expectation, and the second interval's counts
    must be strictly smaller than the first's.
    """
    args = _build_args()
    constant_fields = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [
        {"num_req": num_req, **constant_fields} for num_req in (5000, 10)
    ]
    heavy_load, light_load = samples

    client = _build_prometheus_client(samples)
    prefill_planner, decode_planner, shared_state = _build_planners(args, client)

    # First interval replays the heavy sample, second the light one.
    high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
    low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)

    assert high_p == _expected_prefill(args, prefill_planner, heavy_load)
    assert high_d == _expected_decode(args, decode_planner, heavy_load)
    assert low_p == _expected_prefill(args, prefill_planner, light_load)
    assert low_d == _expected_decode(args, decode_planner, light_load)

    assert low_p < high_p
    assert low_d < high_d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment