Unverified Commit e2f1e04f authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: separate planner into prefill/decode planner (#5622)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 77aadb72
......@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
mode = "disagg" # ["disagg", "prefill", "decode"]
class VllmComponentName:
......
......@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
......@@ -126,6 +128,7 @@ class KubernetesConnector(PlannerConnector):
errors = []
if require_prefill:
try:
get_service_from_sub_component_type_or_name(
deployment,
......@@ -135,6 +138,7 @@ class KubernetesConnector(PlannerConnector):
except PlannerError as e:
errors.append(str(e))
if require_decode:
try:
get_service_from_sub_component_type_or_name(
deployment,
......@@ -145,7 +149,11 @@ class KubernetesConnector(PlannerConnector):
errors.append(str(e))
try:
self.get_model_name(deployment)
self.get_model_name(
deployment,
require_prefill=require_prefill,
require_decode=require_decode,
)
except PlannerError as e:
errors.append(str(e))
......@@ -153,7 +161,12 @@ class KubernetesConnector(PlannerConnector):
if errors:
raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str:
def get_model_name(
self,
deployment: Optional[dict] = None,
require_prefill: bool = True,
require_decode: bool = True,
) -> str:
"""Get the model name from the deployment"""
try:
if deployment is None:
......@@ -163,15 +176,19 @@ class KubernetesConnector(PlannerConnector):
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate
prefill_model_name = None
decode_model_name = None
if require_prefill:
prefill_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
prefill_model_name = prefill_service.get_model_name()
if require_decode:
decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
prefill_model_name = prefill_service.get_model_name()
decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
def run_sla_planner_dryrun(args: argparse.Namespace) -> None:
warmup_metrics = None
if getattr(args, "load_predictor_warmup_trace", None):
warmup_metrics = extract_metrics_from_mooncake(
args.load_predictor_warmup_trace,
args.adjustment_interval,
)
metrics = extract_metrics_from_mooncake(args.dataset, args.adjustment_interval)
if not metrics:
raise ValueError("Empty metrics dataset: cannot run dryrun")
mode = getattr(args, "mode", "disagg")
prefill_planner: Optional[PrefillPlanner] = None
decode_planner: Optional[DecodePlanner] = None
if mode == "disagg":
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(
None, args, dryrun=True, shared_state=shared_state
)
decode_planner = DecodePlanner(
None, args, dryrun=True, shared_state=shared_state
)
elif mode == "prefill":
prefill_planner = PrefillPlanner(None, args, dryrun=True)
elif mode == "decode":
decode_planner = DecodePlanner(None, args, dryrun=True)
else:
raise ValueError(f"Invalid planner mode: {mode}")
def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
"""safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
assert prefill_planner is not None
actual_ttft = prefill_planner.prefill_interpolator.interpolate_ttft(isl)
if actual_ttft > ttft:
return 0
return num_p * prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
isl
)
def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
"""safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
assert decode_planner is not None
(
pred_decode_thpt_per_gpu,
actual_itl,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=itl, context_length=isl + osl / 2
)
if actual_itl > itl:
return 0
return num_d * pred_decode_thpt_per_gpu
time_series = [0]
rr = [metrics[0]["request_count"]]
est_rr = [metrics[0]["request_count"]]
isl = [metrics[0]["avg_isl"]]
est_isl = [metrics[0]["avg_isl"]]
osl = [metrics[0]["avg_osl"]]
est_osl = [metrics[0]["avg_osl"]]
if prefill_planner is not None:
num_p = [args.start_num_p]
p_thpt = [rr[0] * isl[0]]
safe_p_thpt = [
compute_safe_p_thpt(args.start_num_p, isl[0], args.ttft)
* args.adjustment_interval
]
prefill_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])
else:
num_p = [0]
p_thpt = [0]
safe_p_thpt = [0]
if decode_planner is not None:
num_d = [args.start_num_d]
d_thpt = [rr[0] * osl[0]]
safe_d_thpt = [
compute_safe_d_thpt(args.start_num_d, isl[0], osl[0], args.itl)
* args.adjustment_interval
]
decode_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])
else:
num_d = [0]
d_thpt = [0]
safe_d_thpt = [0]
predictor_planner = prefill_planner or decode_planner
assert predictor_planner is not None
for metric in metrics[1:]:
# update time
time_series.append(time_series[-1] + args.adjustment_interval)
# load prediction
_est_rr, _est_isl, _est_osl = predictor_planner.predict_load()
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
# compute num_p and num_d
_num_p = (
prefill_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if prefill_planner is not None
else 0
)
_num_d = (
decode_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if decode_planner is not None
else 0
)
# apply GPU budget
if prefill_planner is not None and decode_planner is not None:
_num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, args)
elif prefill_planner is not None:
_num_p = _apply_component_gpu_budget(
_num_p, args.prefill_engine_num_gpu, args
)
elif decode_planner is not None:
_num_d = _apply_component_gpu_budget(
_num_d, args.decode_engine_num_gpu, args
)
num_p.append(_num_p)
num_d.append(_num_d)
# update load predictor
for planner in [prefill_planner, decode_planner]:
if planner is not None:
planner.dryrun_observe_metrics(
metric["request_count"], metric["avg_isl"], metric["avg_osl"]
)
# fill in ground truth
rr.append(metric["request_count"])
isl.append(metric["avg_isl"])
osl.append(metric["avg_osl"])
p_thpt.append(rr[-1] * isl[-1] if prefill_planner is not None else 0)
d_thpt.append(rr[-1] * osl[-1] if decode_planner is not None else 0)
safe_p_thpt.append(
compute_safe_p_thpt(num_p[-1], isl[-1], args.ttft)
* args.adjustment_interval
if prefill_planner is not None
else 0
)
safe_d_thpt.append(
compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], args.itl)
* args.adjustment_interval
if decode_planner is not None
else 0
)
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
interval = args.adjustment_interval
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time_series,
rr=rr,
est_rr=est_rr,
isl=isl,
est_isl=est_isl,
osl=osl,
est_osl=est_osl,
num_p=num_p,
p_thpt=p_thpt,
safe_p_thpt=safe_p_thpt,
num_d=num_d,
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=args.output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
)
......@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
choices=["vllm", "sglang", "trtllm", "mocker"],
help="Backend type",
)
parser.add_argument(
"--mode",
default=SLAPlannerDefaults.mode,
choices=["disagg", "prefill", "decode"],
help="Planner mode: disagg (prefill+decode), prefill-only, or decode-only",
)
parser.add_argument(
"--no-operation",
action="store_true",
......@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
"--max-gpu-budget",
type=int,
default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget",
help="Maximum GPU budget (-1 for no budget enforcement)",
)
parser.add_argument(
"--min-endpoint",
......
......@@ -6,7 +6,7 @@ import asyncio
import logging
import math
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from prometheus_client import Gauge, start_http_server
......@@ -119,15 +119,110 @@ class PlannerPrometheusMetrics:
self.gpu_hours = Gauge(f"{prefix}:gpu_hours", "Cumulative GPU hours used")
class Planner:
@dataclass
class PlannerSharedState:
last_metrics: Metrics = field(default_factory=Metrics)
p_endpoints: list = field(default_factory=list)
d_endpoints: list = field(default_factory=list)
cumulative_gpu_hours: float = 0.0
last_adjustment_time: float = 0.0
def _apply_global_gpu_budget(
next_num_p: int, next_num_d: int, args: argparse.Namespace
) -> tuple[int, int]:
"""Apply GPU budget constraint to both prefill and decode replicas.
When total GPUs required (num_p * prefill_gpus + num_d * decode_gpus) exceeds the
budget, scale down both proportionally using scale = budget / total_required. Prefill
replicas are clamped to [min_endpoint, max_prefill] where max_prefill reserves enough
GPUs for min_endpoint decode replicas. Remaining budget is then allocated to decode.
Returns (0, 0) if budget cannot satisfy min_endpoint for both components.
"""
if args.max_gpu_budget < 0:
return next_num_p, next_num_d
total_gpu_required = (
next_num_p * args.prefill_engine_num_gpu
+ next_num_d * args.decode_engine_num_gpu
)
if total_gpu_required <= args.max_gpu_budget:
return next_num_p, next_num_d
min_required = (
args.min_endpoint * args.prefill_engine_num_gpu
+ args.min_endpoint * args.decode_engine_num_gpu
)
if args.max_gpu_budget < min_required:
logger.warning(
f"max_gpu_budget ({args.max_gpu_budget}) is below the minimum required "
f"for min_endpoint ({min_required}); enforcing zero replicas"
)
return 0, 0
scale = args.max_gpu_budget / total_gpu_required
max_prefill = math.floor(
(args.max_gpu_budget - args.min_endpoint * args.decode_engine_num_gpu)
/ args.prefill_engine_num_gpu
)
next_num_p = max(
args.min_endpoint, min(max_prefill, math.floor(next_num_p * scale))
)
remaining = args.max_gpu_budget - next_num_p * args.prefill_engine_num_gpu
next_num_d = max(
args.min_endpoint, math.floor(remaining / args.decode_engine_num_gpu)
)
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({args.max_gpu_budget}), "
f"scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
)
return next_num_p, next_num_d
def _apply_component_gpu_budget(
desired_replicas: int, engine_num_gpu: int, args: argparse.Namespace
) -> int:
"""Apply GPU budget constraint to a single component (prefill-only or decode-only).
When total GPUs required (replicas * gpus_per_replica) exceeds the budget, scale down
using scale = budget / total_required, floored and clamped to at least min_endpoint.
Returns 0 if budget cannot satisfy min_endpoint replicas.
"""
if args.max_gpu_budget < 0:
return desired_replicas
total_gpu_required = desired_replicas * engine_num_gpu
if total_gpu_required <= args.max_gpu_budget:
return desired_replicas
min_required = args.min_endpoint * engine_num_gpu
if args.max_gpu_budget < min_required:
logger.warning(
f"max_gpu_budget ({args.max_gpu_budget}) is below the minimum required "
f"for min_endpoint ({min_required}); enforcing zero replicas"
)
return 0
scale = args.max_gpu_budget / total_gpu_required
next_num = max(args.min_endpoint, math.floor(desired_replicas * scale))
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({args.max_gpu_budget}), "
f"scaling down to {next_num} replicas"
)
return next_num
class BasePlanner:
component_type: SubComponentType
def __init__(
self,
runtime: Optional[DistributedRuntime],
args: argparse.Namespace,
dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_api_client: Optional[PrometheusAPIClient] = None,
connector=None,
start_prometheus_server: bool = True,
):
self.args = args
self.dryrun = dryrun
self.shared_state = shared_state or PlannerSharedState()
# Rely on getting model name from connector
self.model_name: Optional[str] = None
......@@ -137,7 +232,9 @@ class Planner:
self.namespace = args.namespace
if not args.no_operation:
if args.environment == "kubernetes":
if connector is not None:
self.connector = connector
elif args.environment == "kubernetes":
self.connector = KubernetesConnector(
self.namespace, self.model_name
)
......@@ -150,7 +247,7 @@ class Planner:
else:
raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient(
self.prometheus_api_client = prometheus_api_client or PrometheusAPIClient(
args.metric_pulling_prometheus_endpoint,
args.namespace,
)
......@@ -231,22 +328,16 @@ class Planner:
if not self.dryrun:
self.prefill_client = None
self.workers_client = None
self.p_endpoints = [] # type: ignore
self.d_endpoints = [] # type: ignore
self.last_adjustment_time = time.time()
self.last_metrics = Metrics()
self.prometheus_port = args.metric_reporting_prometheus_port
# Initialize Prometheus metrics
if prometheus_metrics is None:
self.prometheus_metrics = PlannerPrometheusMetrics()
# Track cumulative GPU hours
self.cumulative_gpu_hours = 0.0
else:
self.prometheus_metrics = prometheus_metrics
# Start Prometheus HTTP server if port is specified
if self.prometheus_port != 0:
if start_prometheus_server and self.prometheus_port != 0:
try:
start_http_server(self.prometheus_port)
logger.info(
......@@ -254,6 +345,9 @@ class Planner:
)
except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}")
else:
self.prometheus_port = 0
self.prometheus_metrics = prometheus_metrics
self.p_correction_factor = 1.0
self.d_correction_factor = 1.0
......@@ -262,6 +356,14 @@ class Planner:
else:
self.no_correction = args.no_correction
@property
def last_metrics(self) -> Metrics:
return self.shared_state.last_metrics
@last_metrics.setter
def last_metrics(self, value: Metrics) -> None:
self.shared_state.last_metrics = value
async def _async_init(self):
"""Async initialization for components that need it"""
if (
......@@ -271,80 +373,94 @@ class Planner:
):
await self.connector._async_init()
async def get_workers_info(self):
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
async def _get_model_name(self, require_prefill: bool, require_decode: bool) -> str:
model_name = self.connector.get_model_name(
require_prefill=require_prefill, require_decode=require_decode
)
if asyncio.iscoroutine(model_name):
model_name = await model_name
return model_name
try:
if self.prefill_client is None:
self.prefill_client = (
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
client = (
await self.runtime.namespace(self.namespace)
.component(
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_component_name
)
.endpoint(
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_endpoint
)
.component(component_name)
.endpoint(endpoint_name)
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling instance_ids
return client
async def get_workers_info(
self, require_prefill: bool = True, require_decode: bool = True
):
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
p_endpoints = []
d_endpoints = []
worker_names = WORKER_COMPONENT_NAMES[self.args.backend]
if require_prefill:
try:
if self.prefill_client is None:
self.prefill_client = await self._get_or_create_client(
worker_names.prefill_worker_component_name,
worker_names.prefill_worker_endpoint,
)
p_endpoints = self.prefill_client.instance_ids() # type: ignore
except Exception:
p_endpoints = []
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
)
if require_decode:
try:
if self.workers_client is None:
self.workers_client = (
await self.runtime.namespace(self.namespace)
.component(
WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_component_name
self.workers_client = await self._get_or_create_client(
worker_names.decode_worker_component_name,
worker_names.decode_worker_endpoint,
)
.endpoint(
WORKER_COMPONENT_NAMES[self.args.backend].decode_worker_endpoint
)
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling instance_ids
d_endpoints = self.workers_client.instance_ids() # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints
async def observe_metrics(self):
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
async def observe_metrics(
self, require_prefill: bool = True, require_decode: bool = True
):
p_endpoints, d_endpoints = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.p_endpoints = p_endpoints
self.shared_state.d_endpoints = d_endpoints
logger.debug(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
f"Number of prefill workers: {len(p_endpoints)}, number of decode workers: {len(d_endpoints)}"
)
# Update Prometheus metrics if server is running
if self.prometheus_port != 0:
self.prometheus_metrics.num_p_workers.set(len(self.p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(self.d_endpoints))
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.num_p_workers.set(len(p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(d_endpoints))
# Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours = (
(
len(self.p_endpoints) * self.args.prefill_engine_num_gpu
+ len(self.d_endpoints) * self.args.decode_engine_num_gpu
len(p_endpoints) * self.args.prefill_engine_num_gpu
+ len(d_endpoints) * self.args.decode_engine_num_gpu
)
* self.args.adjustment_interval
/ 3600
)
self.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(self.cumulative_gpu_hours)
self.shared_state.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(
self.shared_state.cumulative_gpu_hours
)
# Prometheus returns seconds, convert to milliseconds
self.last_metrics.ttft = (
......@@ -392,7 +508,7 @@ class Planner:
)
# Update observed metrics in Prometheus
if self.prometheus_port != 0:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft)
self.prometheus_metrics.observed_itl.set(self.last_metrics.itl)
self.prometheus_metrics.observed_request_rate.set(
......@@ -404,9 +520,12 @@ class Planner:
self.prometheus_metrics.observed_isl.set(self.last_metrics.isl)
self.prometheus_metrics.observed_osl.set(self.last_metrics.osl)
self.num_req_predictor.add_data_point(self.last_metrics.num_req)
self.isl_predictor.add_data_point(self.last_metrics.isl)
self.osl_predictor.add_data_point(self.last_metrics.osl)
self.update_predictors_from_metrics(self.last_metrics)
def update_predictors_from_metrics(self, metrics: Metrics) -> None:
self.num_req_predictor.add_data_point(metrics.num_req)
self.isl_predictor.add_data_point(metrics.isl)
self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self):
try:
......@@ -427,153 +546,28 @@ class Planner:
self.isl_predictor.add_data_point(isl_avg)
self.osl_predictor.add_data_point(osl_avg)
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> tuple[int, int]:
"""Compute the number of prefill and decode replicas needed based on predicted load.
Args:
next_num_req: Predicted number of requests
next_isl: Predicted input sequence length
next_osl: Predicted output sequence length
Returns:
tuple[int, int]: Number of prefill and decode replicas needed
"""
# compute how many replicas are needed for prefill
# here we assume the prefill bias is purely due to request queueing
# and we increase the number of prefill replicas linearly to account for the queueing delay
pred_prefill_throughput = (
next_num_req
* next_isl
/ self.args.adjustment_interval
* min(1, self.p_correction_factor)
)
next_num_p = math.ceil(
pred_prefill_throughput
/ self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
/ self.args.prefill_engine_num_gpu
)
logger.info(
f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / "
f"{self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl) * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = "
f"{next_num_p}(num_p)"
)
# compute how many replicas are needed for decode
# 1. apply d_correction_factor to the ITL SLA
# Prevent divide by zero when d_correction_factor is 0 (no metrics yet)
if self.d_correction_factor <= 0:
logger.warning(
f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0"
)
corrected_itl = self.args.itl
else:
corrected_itl = self.args.itl / self.d_correction_factor
# 2. reversely find out what is best throughput/gpu that can achieve corrected_itl under the predicted context length
(
pred_decode_thpt_per_gpu,
_,
_,
) = self.decode_interpolator.find_best_throughput_per_gpu(
itl=corrected_itl, context_length=next_isl + next_osl / 2
)
# 3. compute number of decode replicas needed
pred_decode_throughput = next_num_req * next_osl / self.args.adjustment_interval
next_num_d = math.ceil(
pred_decode_throughput
/ pred_decode_thpt_per_gpu
/ self.args.decode_engine_num_gpu
)
logger.info(
f"Decode calculation: {pred_decode_throughput:.2f}(d_thpt) / "
f"{pred_decode_thpt_per_gpu * self.args.decode_engine_num_gpu:.2f}(d_engine_cap) = "
f"{next_num_d}(num_d)"
)
# correct num_p and num_d based on the gpu budget
next_num_p = max(next_num_p, self.args.min_endpoint)
next_num_d = max(next_num_d, self.args.min_endpoint)
logger.info(
f"Predicted number of engine replicas: prefill={next_num_p}, decode={next_num_d}"
)
total_gpu_required = (
next_num_p * self.args.prefill_engine_num_gpu
+ next_num_d * self.args.decode_engine_num_gpu
)
if total_gpu_required > self.args.max_gpu_budget:
scale = self.args.max_gpu_budget / total_gpu_required
next_num_p = max(self.args.min_endpoint, round(next_num_p * scale))
next_num_d = max(
self.args.min_endpoint,
round(
(
self.args.max_gpu_budget
- next_num_p * self.args.prefill_engine_num_gpu
)
/ self.args.decode_engine_num_gpu
),
)
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({self.args.max_gpu_budget}), scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
)
return next_num_p, next_num_d
async def make_adjustments(self):
def plan_adjustment(self) -> Optional[int]:
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
return None
if not self.no_correction:
try:
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
# first correct the prediction correction factor
# for TTFT, we expect the correction factor to be << 1 due to queuing delay
expect_ttft = self.prefill_interpolator.interpolate_ttft(
self.last_metrics.isl
)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for ITL, we expect the correction factor to be close to 1
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
# Update correction factor metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.p_correction_factor.set(
self.p_correction_factor
)
self.prometheus_metrics.d_correction_factor.set(
self.d_correction_factor
)
if not self._update_correction_factor():
return None
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
return None
next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is None or next_isl is None or next_osl is None:
return None
if next_num_req is not None and next_isl is not None and next_osl is not None:
# Update predicted load metrics in Prometheus
if self.prometheus_port != 0:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_request_rate.set(
next_num_req / self.args.adjustment_interval
)
......@@ -581,216 +575,308 @@ class Planner:
self.prometheus_metrics.predicted_osl.set(next_osl)
try:
next_num_p, next_num_d = self._compute_replica_requirements(
next_num_req, next_isl, next_osl
)
# Update predicted replica metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_num_p.set(next_num_p)
self.prometheus_metrics.predicted_num_d.set(next_num_d)
return self._compute_replica_requirements(next_num_req, next_isl, next_osl)
except Exception as e:
logger.error(f"Failed to compute number of replicas: {e}")
return
return None
if not self.args.no_operation:
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
raise NotImplementedError
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
raise NotImplementedError
def _update_correction_factor(self) -> bool:
raise NotImplementedError
def _component_name(self) -> str:
if self.component_type == SubComponentType.PREFILL:
return self.prefill_component_name
return self.decode_component_name
def _engine_num_gpu(self) -> int:
if self.component_type == SubComponentType.PREFILL:
return self.args.prefill_engine_num_gpu
return self.args.decode_engine_num_gpu
def apply_component_budget(self, desired_replicas: int) -> int:
return _apply_component_gpu_budget(
desired_replicas, self._engine_num_gpu(), self.args
)
async def _apply_scaling(self, desired_replicas: int) -> None:
if self.args.no_operation:
return
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_component_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.decode_component_name,
desired_replicas=next_num_d,
),
sub_component_type=self.component_type,
component_name=self._component_name(),
desired_replicas=desired_replicas,
)
]
await self.connector.set_component_replicas(target_replicas, blocking=False)
async def run(self):
"""Main loop for the planner"""
if not self.args.no_operation:
# Fail fast if the deployment is not valid
logger.info("Validating deployment...")
# TODO: still supporting framework component names for backwards compatibility
# Should be deprecated in favor of service subComponentType
require_prefill = self.component_type == SubComponentType.PREFILL
require_decode = self.component_type == SubComponentType.DECODE
await self.connector.validate_deployment(
prefill_component_name=self.prefill_component_name,
decode_component_name=self.decode_component_name,
prefill_component_name=(
self.prefill_component_name if require_prefill else None
),
decode_component_name=(
self.decode_component_name if require_decode else None
),
require_prefill=require_prefill,
require_decode=require_decode,
)
logger.info("Successfully validated the deployment")
await self.connector.wait_for_deployment_ready()
model_name = self.connector.get_model_name()
model_name = await self._get_model_name(
require_prefill=require_prefill, require_decode=require_decode
)
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = (
model_name.lower()
) # normalize model name to lowercase (MDC)
self.last_adjustment_time = time.time()
self.shared_state.last_adjustment_time = time.time()
while True:
current_time = time.time()
if (
current_time - self.last_adjustment_time
current_time - self.shared_state.last_adjustment_time
>= self.args.adjustment_interval
):
self.last_adjustment_time = time.time()
self.shared_state.last_adjustment_time = time.time()
logger.info("New adjustment interval started!")
await self.observe_metrics()
await self.make_adjustments()
await self.observe_metrics(
require_prefill=require_prefill, require_decode=require_decode
)
desired_replicas = self.plan_adjustment()
if desired_replicas is not None:
desired_replicas = self.apply_component_budget(desired_replicas)
self.update_predicted_replicas_metric(desired_replicas)
await self._apply_scaling(desired_replicas)
# sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
await asyncio.sleep(self.args.adjustment_interval / 10)
def dryrun_run(self):
"""Run planner in dry-run mode with dataset"""
warmup_metrics = None
if getattr(self.args, "load_predictor_warmup_trace", None):
warmup_metrics = extract_metrics_from_mooncake(
self.args.load_predictor_warmup_trace,
self.args.adjustment_interval,
class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL
def _update_correction_factor(self) -> bool:
expect_ttft = self.prefill_interpolator.interpolate_ttft(self.last_metrics.isl)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.p_correction_factor.set(self.p_correction_factor)
return True
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
pred_prefill_throughput = (
next_num_req
* next_isl
/ self.args.adjustment_interval
* min(1, self.p_correction_factor)
)
p_thpt_per_gpu = self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
next_num_p = math.ceil(
pred_prefill_throughput / p_thpt_per_gpu / self.args.prefill_engine_num_gpu
)
next_num_p = max(next_num_p, self.args.min_endpoint)
logger.info(
f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / "
f"{p_thpt_per_gpu * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = "
f"{next_num_p}(num_p)"
)
return next_num_p
metrics = extract_metrics_from_mooncake(
self.args.dataset, self.args.adjustment_interval
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_num_p.set(desired_replicas)
class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE
def _update_correction_factor(self) -> bool:
if not self.shared_state.d_endpoints:
logger.warning(
"No decode workers found for correction factor, skipping correction update"
)
return True
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.shared_state.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.d_correction_factor.set(self.d_correction_factor)
return True
def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
"""safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
actual_ttft = self.prefill_interpolator.interpolate_ttft(isl)
if actual_ttft > ttft:
return 0
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
if self.d_correction_factor <= 0:
logger.warning(
f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0"
)
corrected_itl = self.args.itl
else:
return num_p * self.prefill_interpolator.interpolate_thpt_per_gpu(isl)
def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
"""safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
corrected_itl = self.args.itl / self.d_correction_factor
(
pred_decode_thpt_per_gpu,
actual_itl,
_,
_,
) = self.decode_interpolator.find_best_throughput_per_gpu(
itl=itl, context_length=isl + osl / 2
)
if actual_itl > itl:
return 0
else:
return num_d * pred_decode_thpt_per_gpu
time = [0]
rr = [metrics[0]["request_count"]]
est_rr = [metrics[0]["request_count"]]
isl = [metrics[0]["avg_isl"]]
est_isl = [metrics[0]["avg_isl"]]
osl = [metrics[0]["avg_osl"]]
est_osl = [metrics[0]["avg_osl"]]
num_p = [self.args.start_num_p]
p_thpt = [metrics[0]["request_count"] * metrics[0]["avg_isl"]]
safe_p_thpt = [
compute_safe_p_thpt(
self.args.start_num_p, metrics[0]["avg_isl"], self.args.ttft
itl=corrected_itl, context_length=next_isl + next_osl / 2
)
* self.args.adjustment_interval
]
num_d = [self.args.start_num_d]
d_thpt = [metrics[0]["request_count"] * metrics[0]["avg_osl"]]
safe_d_thpt = [
compute_safe_d_thpt(
self.args.start_num_d,
metrics[0]["avg_isl"],
metrics[0]["avg_osl"],
self.args.itl,
pred_decode_throughput = next_num_req * next_osl / self.args.adjustment_interval
next_num_d = math.ceil(
pred_decode_throughput
/ pred_decode_thpt_per_gpu
/ self.args.decode_engine_num_gpu
)
* self.args.adjustment_interval
]
self.dryrun_observe_metrics(
metrics[0]["request_count"], metrics[0]["avg_isl"], metrics[0]["avg_osl"]
next_num_d = max(next_num_d, self.args.min_endpoint)
logger.info(
f"Decode calculation: {pred_decode_throughput:.2f}(d_thpt) / "
f"{pred_decode_thpt_per_gpu * self.args.decode_engine_num_gpu:.2f}(d_engine_cap) = "
f"{next_num_d}(num_d)"
)
return next_num_d
for metric in metrics[1:]:
# update time
time.append(time[-1] + self.args.adjustment_interval)
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.predicted_num_d.set(desired_replicas)
# load prediction
_est_rr, _est_isl, _est_osl = self.predict_load()
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
# compute num_p and num_d
_num_p, _num_d = self._compute_replica_requirements(
_est_rr, _est_isl, _est_osl
class DisaggPlanner:
def __init__(
self, runtime: Optional[DistributedRuntime], args: argparse.Namespace
) -> None:
self.args = args
self.shared_state = PlannerSharedState()
prometheus_metrics = PlannerPrometheusMetrics()
self.prefill_planner = PrefillPlanner(
runtime,
args,
shared_state=self.shared_state,
prometheus_metrics=prometheus_metrics,
start_prometheus_server=True,
)
self.decode_planner = DecodePlanner(
runtime,
args,
shared_state=self.shared_state,
prometheus_metrics=prometheus_metrics,
prometheus_api_client=getattr(
self.prefill_planner, "prometheus_api_client", None
),
connector=getattr(self.prefill_planner, "connector", None),
start_prometheus_server=False,
)
num_p.append(_num_p)
num_d.append(_num_d)
# update load predictor
self.dryrun_observe_metrics(
metric["request_count"], metric["avg_isl"], metric["avg_osl"]
async def _async_init(self):
# Prefill/Decode share the same connector instance in disagg mode.
await self.prefill_planner._async_init()
async def run(self):
if not self.args.no_operation:
logger.info("Validating deployment...")
await self.prefill_planner.connector.validate_deployment(
prefill_component_name=self.prefill_planner.prefill_component_name,
decode_component_name=self.prefill_planner.decode_component_name,
require_prefill=True,
require_decode=True,
)
logger.info("Successfully validated the deployment")
# fill in ground truth
rr.append(metric["request_count"])
isl.append(metric["avg_isl"])
osl.append(metric["avg_osl"])
await self.prefill_planner.connector.wait_for_deployment_ready()
p_thpt.append(rr[-1] * isl[-1])
d_thpt.append(rr[-1] * osl[-1])
model_name = await self.prefill_planner._get_model_name(
require_prefill=True, require_decode=True
)
logger.info(f"Detected model name from deployment: {model_name}")
model_name = model_name.lower()
self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name
safe_p_thpt.append(
compute_safe_p_thpt(num_p[-1], isl[-1], self.args.ttft)
* self.args.adjustment_interval
self.shared_state.last_adjustment_time = time.time()
while True:
current_time = time.time()
if (
current_time - self.shared_state.last_adjustment_time
>= self.args.adjustment_interval
):
self.shared_state.last_adjustment_time = time.time()
logger.info("New adjustment interval started!")
await self.prefill_planner.observe_metrics(
require_prefill=True, require_decode=True
)
safe_d_thpt.append(
compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], self.args.itl)
* self.args.adjustment_interval
self.decode_planner.update_predictors_from_metrics(
self.shared_state.last_metrics
)
next_num_p = self.prefill_planner.plan_adjustment()
next_num_d = self.decode_planner.plan_adjustment()
if next_num_p is None or next_num_d is None:
continue
# plot the results
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
interval = self.args.adjustment_interval
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time,
rr=rr,
est_rr=est_rr,
isl=isl,
est_isl=est_isl,
osl=osl,
est_osl=est_osl,
num_p=num_p,
p_thpt=p_thpt,
safe_p_thpt=safe_p_thpt,
num_d=num_d,
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=self.args.output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
next_num_p, next_num_d = _apply_global_gpu_budget(
next_num_p, next_num_d, self.args
)
self.prefill_planner.update_predicted_replicas_metric(next_num_p)
self.decode_planner.update_predicted_replicas_metric(next_num_d)
if not self.args.no_operation:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name,
desired_replicas=next_num_d,
),
]
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=False
)
# sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
await asyncio.sleep(self.args.adjustment_interval / 10)
async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args)
mode = getattr(args, "mode", "disagg")
if mode == "disagg":
planner = DisaggPlanner(runtime, args)
elif mode == "prefill":
planner = PrefillPlanner(runtime, args)
elif mode == "decode":
planner = DecodePlanner(runtime, args)
else:
raise ValueError(f"Invalid planner mode: {mode}")
await planner._async_init()
await planner.run()
......@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""Validate the deployment"""
pass
......@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment to be ready"""
await self._wait_for_scaling_completion()
async def get_model_name(self) -> str:
async def get_model_name(
self, require_prefill: bool = True, require_decode: bool = True
) -> str:
"""Get the model name from the deployment"""
del require_prefill, require_decode
return self.model_name
......@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime
sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402
from dynamo.planner.utils.planner_core import ( # noqa: E402
DecodePlanner,
Metrics,
PlannerSharedState,
PrefillPlanner,
_apply_global_gpu_budget,
)
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
class PlannerHarness:
def __init__(self, prefill_planner, decode_planner, shared_state):
self.prefill_planner = prefill_planner
self.decode_planner = decode_planner
self.shared_state = shared_state
self.last_target_replicas = []
async def make_adjustments(self):
if not self.shared_state.last_metrics.is_valid():
return
p_endpoints, d_endpoints = await self.prefill_planner.get_workers_info()
self.shared_state.p_endpoints = p_endpoints
self.shared_state.d_endpoints = d_endpoints
next_num_p = self.prefill_planner.plan_adjustment()
next_num_d = self.decode_planner.plan_adjustment()
if next_num_p is None or next_num_d is None:
return
next_num_p, next_num_d = _apply_global_gpu_budget(
next_num_p, next_num_d, self.prefill_planner.args
)
self.prefill_planner.update_predicted_replicas_metric(next_num_p)
self.decode_planner.update_predicted_replicas_metric(next_num_d)
target_replicas = [
{
"sub_component_type": "prefill",
"component_name": self.prefill_planner.prefill_component_name,
"desired_replicas": next_num_p,
},
{
"sub_component_type": "decode",
"component_name": self.prefill_planner.decode_component_name,
"desired_replicas": next_num_d,
},
]
self.last_target_replicas = target_replicas
if not self.prefill_planner.args.no_operation:
await self.prefill_planner.connector.set_component_replicas(
target_replicas, blocking=False
)
def __getattr__(self, name):
shared_attrs = {
"num_req_predictor",
"isl_predictor",
"osl_predictor",
"connector",
"prometheus_api_client",
"args",
}
prefill_attrs = {
"prefill_interpolator",
"prefill_component_name",
"p_correction_factor",
}
decode_attrs = {
"decode_interpolator",
"decode_component_name",
"d_correction_factor",
}
if name == "last_metrics":
return self.shared_state.last_metrics
if name == "get_workers_info":
return self.prefill_planner.get_workers_info
if name in shared_attrs:
return getattr(self.prefill_planner, name)
if name in prefill_attrs:
return getattr(self.prefill_planner, name)
if name in decode_attrs:
return getattr(self.decode_planner, name)
raise AttributeError(name)
def __setattr__(self, name, value):
if name in {"prefill_planner", "decode_planner", "shared_state"}:
return super().__setattr__(name, value)
shared_attrs = {
"num_req_predictor",
"isl_predictor",
"osl_predictor",
"connector",
"prometheus_api_client",
"args",
"get_workers_info",
}
prefill_attrs = {"prefill_interpolator", "p_correction_factor"}
decode_attrs = {"decode_interpolator", "d_correction_factor"}
if name == "last_metrics":
self.shared_state.last_metrics = value
return None
if name in shared_attrs:
# Store locally to support patch.object lifecycle (set/del).
object.__setattr__(self, name, value)
setattr(self.prefill_planner, name, value)
setattr(self.decode_planner, name, value)
return None
if name in prefill_attrs:
setattr(self.prefill_planner, name, value)
return None
if name in decode_attrs:
setattr(self.decode_planner, name, value)
return None
return super().__setattr__(name, value)
def _replica_count(target_replicas, component_name, default=1):
for replica in target_replicas:
if replica.get("component_name") == component_name:
return replica.get("desired_replicas", default)
return default
@pytest.fixture
def planner():
"""Set up test environment with mocked dependencies."""
......@@ -75,8 +196,10 @@ def planner():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
# Create planner instance
planner = Planner(mock_runtime, args)
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(mock_runtime, args, shared_state=shared_state)
decode_planner = DecodePlanner(mock_runtime, args, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Mock the interpolators to return fixed values for testing
planner.prefill_interpolator = Mock()
......@@ -165,11 +288,10 @@ class TestReplicaCalculation:
# Extract the calculated values from the log calls or by checking the mock calls
# Since we mocked the connector, we can check what replicas were requested
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_component = "VllmPrefillWorker"
calculated_prefill_replicas = call_args.get(prefill_component, 1)
calculated_prefill_replicas = _replica_count(
planner.last_target_replicas, prefill_component
)
print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
......@@ -230,11 +352,10 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Check the results
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
decode_component = "VllmDecodeWorker"
calculated_decode_replicas = call_args.get(decode_component, 1)
calculated_decode_replicas = _replica_count(
planner.last_target_replicas, decode_component
)
print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}")
......@@ -304,12 +425,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify results
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
assert (
......@@ -359,12 +480,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify that total GPU usage doesn't exceed budget
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu
......@@ -417,12 +538,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify minimum constraints are respected
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
assert (
......@@ -482,9 +603,9 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments())
# Verify that correction factor was effectively clamped
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
print(
f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}"
......@@ -501,7 +622,6 @@ class TestReplicaCalculation:
"""Test handling of d_correction_factor <= 0."""
# Test both 0 and negative values
for correction_factor in [0.0, -1.0]:
with patch.object(planner, "connector") as mock_connector:
planner.p_correction_factor = 1.0
planner.d_correction_factor = correction_factor
......@@ -511,9 +631,7 @@ class TestReplicaCalculation:
planner.osl_predictor.predict_next.return_value = 150
# Mock interpolator outputs
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = (
40000
)
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = 40000
planner.decode_interpolator.find_best_throughput_per_gpu.return_value = (
10000,
0.01,
......@@ -545,9 +663,9 @@ class TestReplicaCalculation:
# Should handle gracefully without crashing
# The code should use args.itl directly instead of dividing by 0
if mock_connector.set_component_replicas.called:
call_args = mock_connector.set_component_replicas.call_args[0][0]
decode_replicas = call_args.get("VllmDecodeWorker", 1)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(
f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}"
......@@ -608,12 +726,12 @@ class TestReplicaCalculation:
# Run calculation
asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})"
)
......@@ -668,12 +786,12 @@ class TestReplicaCalculation:
# Run calculation
asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_replicas = call_args.get("VllmPrefillWorker", 1)
decode_replicas = call_args.get("VllmDecodeWorker", 1)
prefill_replicas = _replica_count(
planner.last_target_replicas, "VllmPrefillWorker"
)
decode_replicas = _replica_count(
planner.last_target_replicas, "VllmDecodeWorker"
)
# Verify total GPU usage doesn't exceed budget
total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu
......
......@@ -15,8 +15,8 @@
import logging
from dynamo.planner.utils.dryrun import run_sla_planner_dryrun
from dynamo.planner.utils.planner_argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__)
......@@ -45,5 +45,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
planner = Planner(None, args, dryrun=True)
planner.dryrun_run()
run_sla_planner_dryrun(args)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
yield
def _build_args():
args = argparse.Namespace()
args.adjustment_interval = 60
args.prefill_engine_num_gpu = 1
args.decode_engine_num_gpu = 1
args.min_endpoint = 1
args.max_gpu_budget = -1
args.ttft = 500.0
args.itl = 50.0
args.backend = "vllm"
args.no_operation = True
args.no_correction = True
args.metric_pulling_prometheus_endpoint = "http://localhost:9090"
args.metric_reporting_prometheus_port = 0
args.load_predictor = "constant"
args.load_predictor_warmup_trace = None
args.profile_results_dir = os.path.join(
os.path.dirname(__file__),
"..",
"profiling_results",
"H200_TP1P_TP1D",
)
args.environment = "kubernetes"
args.namespace = "test-namespace"
args.mode = "disagg"
return args
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(args, prometheus_client):
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(None, args, shared_state=shared_state)
decode_planner = DecodePlanner(None, args, shared_state=shared_state)
prefill_planner.prometheus_api_client = prometheus_client
decode_planner.prometheus_api_client = prometheus_client
prefill_planner.model_name = "test-model"
decode_planner.model_name = "test-model"
async def mock_get_workers_info(require_prefill=True, require_decode=True):
return (
["prefill-0"] if require_prefill else [],
["decode-0"] if require_decode else [],
)
prefill_planner.get_workers_info = mock_get_workers_info
decode_planner.get_workers_info = mock_get_workers_info
return prefill_planner, decode_planner, shared_state
def _expected_prefill(args, prefill_planner, sample):
pred_prefill_throughput = (
sample["num_req"] * sample["isl"] / args.adjustment_interval
)
thpt_per_gpu = prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
sample["isl"]
)
expected = math.ceil(
pred_prefill_throughput / thpt_per_gpu / args.prefill_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _expected_decode(args, decode_planner, sample):
(
pred_decode_thpt_per_gpu,
_,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=args.itl, context_length=sample["isl"] + sample["osl"] / 2
)
pred_decode_throughput = (
sample["num_req"] * sample["osl"] / args.adjustment_interval
)
expected = math.ceil(
pred_decode_throughput / pred_decode_thpt_per_gpu / args.decode_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
asyncio.run(
prefill_planner.observe_metrics(require_prefill=True, require_decode=True)
)
decode_planner.update_predictors_from_metrics(shared_state.last_metrics)
next_num_p = prefill_planner.plan_adjustment()
next_num_d = decode_planner.plan_adjustment()
return next_num_p, next_num_d
def test_disagg_scale_up():
args = _build_args()
samples = [
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(args, client)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert low_p == _expected_prefill(args, prefill_planner, samples[0])
assert low_d == _expected_decode(args, decode_planner, samples[0])
assert high_p == _expected_prefill(args, prefill_planner, samples[1])
assert high_d == _expected_decode(args, decode_planner, samples[1])
assert high_p > low_p
assert high_d > low_d
def test_disagg_scale_down():
args = _build_args()
samples = [
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(args, client)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert high_p == _expected_prefill(args, prefill_planner, samples[0])
assert high_d == _expected_decode(args, decode_planner, samples[0])
assert low_p == _expected_prefill(args, prefill_planner, samples[1])
assert low_d == _expected_decode(args, decode_planner, samples[1])
assert low_p < high_p
assert low_d < high_d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment