Commit e2f1e04f (unverified), authored by Hongkuan Zhou and committed by GitHub
Browse files

refactor: separate planner into prefill/decode planner (#5622)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 77aadb72
...@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_min_points = 5 kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
mode = "disagg" # ["disagg", "prefill", "decode"]
class VllmComponentName: class VllmComponentName:
......
...@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector): ...@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector):
self, self,
prefill_component_name: Optional[str] = None, prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None, decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
): ):
""" """
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists. Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
...@@ -126,34 +128,45 @@ class KubernetesConnector(PlannerConnector): ...@@ -126,34 +128,45 @@ class KubernetesConnector(PlannerConnector):
errors = [] errors = []
try: if require_prefill:
get_service_from_sub_component_type_or_name( try:
deployment, get_service_from_sub_component_type_or_name(
SubComponentType.PREFILL, deployment,
component_name=prefill_component_name, SubComponentType.PREFILL,
) component_name=prefill_component_name,
except PlannerError as e: )
errors.append(str(e)) except PlannerError as e:
errors.append(str(e))
if require_decode:
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
)
except PlannerError as e:
errors.append(str(e))
try: try:
get_service_from_sub_component_type_or_name( self.get_model_name(
deployment, deployment,
SubComponentType.DECODE, require_prefill=require_prefill,
component_name=decode_component_name, require_decode=require_decode,
) )
except PlannerError as e: except PlannerError as e:
errors.append(str(e)) errors.append(str(e))
try:
self.get_model_name(deployment)
except PlannerError as e:
errors.append(str(e))
# Raise combined error if any issues found # Raise combined error if any issues found
if errors: if errors:
raise DeploymentValidationError(errors) raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str: def get_model_name(
self,
deployment: Optional[dict] = None,
require_prefill: bool = True,
require_decode: bool = True,
) -> str:
"""Get the model name from the deployment""" """Get the model name from the deployment"""
try: try:
if deployment is None: if deployment is None:
...@@ -163,16 +176,20 @@ class KubernetesConnector(PlannerConnector): ...@@ -163,16 +176,20 @@ class KubernetesConnector(PlannerConnector):
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing # TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate # and model name logic, should consolidate
prefill_service = get_service_from_sub_component_type_or_name( prefill_model_name = None
deployment, decode_model_name = None
SubComponentType.PREFILL, if require_prefill:
) prefill_service = get_service_from_sub_component_type_or_name(
decode_service = get_service_from_sub_component_type_or_name( deployment,
deployment, SubComponentType.PREFILL,
SubComponentType.DECODE, )
) prefill_model_name = prefill_service.get_model_name()
prefill_model_name = prefill_service.get_model_name() if require_decode:
decode_model_name = decode_service.get_model_name() decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None: if prefill_model_name is None and decode_model_name is None:
raise ModelNameNotFoundError() raise ModelNameNotFoundError()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
def run_sla_planner_dryrun(args: argparse.Namespace) -> None:
    """Replay a recorded trace through the planner(s) in dry-run mode and plot it.

    Metrics are extracted from ``args.dataset`` at ``args.adjustment_interval``
    granularity. Depending on ``args.mode`` ("disagg", "prefill" or "decode";
    "disagg" when the attribute is absent) a prefill planner, a decode planner,
    or both are created with ``dryrun=True``. For every interval after the
    first, the load is predicted, replica counts are computed and clamped by the
    GPU budget, and the observed metrics are then fed back to the planner(s).
    The collected series are handed to ``create_dryrun_plot``.

    Raises:
        ValueError: if the dataset yields no metrics or ``args.mode`` is not
            one of the supported modes.
    """
    # Optional warm-up trace; in this function it is only used for plotting.
    warmup_metrics = None
    warmup_trace = getattr(args, "load_predictor_warmup_trace", None)
    if warmup_trace:
        warmup_metrics = extract_metrics_from_mooncake(
            warmup_trace, args.adjustment_interval
        )

    metrics = extract_metrics_from_mooncake(args.dataset, args.adjustment_interval)
    if not metrics:
        raise ValueError("Empty metrics dataset: cannot run dryrun")

    mode = getattr(args, "mode", "disagg")
    if mode not in ("disagg", "prefill", "decode"):
        raise ValueError(f"Invalid planner mode: {mode}")

    prefill_planner: Optional[PrefillPlanner] = None
    decode_planner: Optional[DecodePlanner] = None
    if mode == "disagg":
        # Both planners receive the same shared-state object.
        shared_state = PlannerSharedState()
        prefill_planner = PrefillPlanner(
            None, args, dryrun=True, shared_state=shared_state
        )
        decode_planner = DecodePlanner(
            None, args, dryrun=True, shared_state=shared_state
        )
    elif mode == "prefill":
        prefill_planner = PrefillPlanner(None, args, dryrun=True)
    else:
        decode_planner = DecodePlanner(None, args, dryrun=True)

    def safe_prefill_thpt(num_p: int, isl: float, ttft: float):
        """safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
        assert prefill_planner is not None
        interpolator = prefill_planner.prefill_interpolator
        if interpolator.interpolate_ttft(isl) > ttft:
            return 0
        return num_p * interpolator.interpolate_thpt_per_gpu(isl)

    def safe_decode_thpt(num_d: int, isl: float, osl: float, itl: float):
        """safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
        assert decode_planner is not None
        best = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
            itl=itl, context_length=isl + osl / 2
        )
        thpt_per_gpu, achieved_itl, _ = best
        if achieved_itl > itl:
            return 0
        return num_d * thpt_per_gpu

    # Seed every series with the first interval of the trace.
    first = metrics[0]
    time_series = [0]
    rr = [first["request_count"]]
    est_rr = [first["request_count"]]
    isl = [first["avg_isl"]]
    est_isl = [first["avg_isl"]]
    osl = [first["avg_osl"]]
    est_osl = [first["avg_osl"]]

    if prefill_planner is None:
        num_p = [0]
        p_thpt = [0]
        safe_p_thpt = [0]
    else:
        num_p = [args.start_num_p]
        p_thpt = [rr[0] * isl[0]]
        safe_p_thpt = [
            safe_prefill_thpt(args.start_num_p, isl[0], args.ttft)
            * args.adjustment_interval
        ]
        prefill_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])

    if decode_planner is None:
        num_d = [0]
        d_thpt = [0]
        safe_d_thpt = [0]
    else:
        num_d = [args.start_num_d]
        d_thpt = [rr[0] * osl[0]]
        safe_d_thpt = [
            safe_decode_thpt(args.start_num_d, isl[0], osl[0], args.itl)
            * args.adjustment_interval
        ]
        decode_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])

    # Load prediction is taken from the prefill planner when present,
    # otherwise from the decode planner.
    predictor_planner = prefill_planner or decode_planner
    assert predictor_planner is not None

    for metric in metrics[1:]:
        # update time
        time_series.append(time_series[-1] + args.adjustment_interval)

        # load prediction
        pred_rr, pred_isl, pred_osl = predictor_planner.predict_load()
        est_rr.append(pred_rr)
        est_isl.append(pred_isl)
        est_osl.append(pred_osl)

        # compute num_p and num_d
        next_p = 0
        if prefill_planner is not None:
            next_p = prefill_planner._compute_replica_requirements(
                pred_rr, pred_isl, pred_osl
            )
        next_d = 0
        if decode_planner is not None:
            next_d = decode_planner._compute_replica_requirements(
                pred_rr, pred_isl, pred_osl
            )

        # apply GPU budget
        if prefill_planner is not None and decode_planner is not None:
            next_p, next_d = _apply_global_gpu_budget(next_p, next_d, args)
        elif prefill_planner is not None:
            next_p = _apply_component_gpu_budget(
                next_p, args.prefill_engine_num_gpu, args
            )
        elif decode_planner is not None:
            next_d = _apply_component_gpu_budget(
                next_d, args.decode_engine_num_gpu, args
            )
        num_p.append(next_p)
        num_d.append(next_d)

        # update load predictor (only after the prediction for this step was taken)
        for planner in (prefill_planner, decode_planner):
            if planner is None:
                continue
            planner.dryrun_observe_metrics(
                metric["request_count"], metric["avg_isl"], metric["avg_osl"]
            )

        # fill in ground truth
        rr.append(metric["request_count"])
        isl.append(metric["avg_isl"])
        osl.append(metric["avg_osl"])

        if prefill_planner is not None:
            p_thpt.append(rr[-1] * isl[-1])
            safe_p_thpt.append(
                safe_prefill_thpt(num_p[-1], isl[-1], args.ttft)
                * args.adjustment_interval
            )
        else:
            p_thpt.append(0)
            safe_p_thpt.append(0)

        if decode_planner is not None:
            d_thpt.append(rr[-1] * osl[-1])
            safe_d_thpt.append(
                safe_decode_thpt(num_d[-1], isl[-1], osl[-1], args.itl)
                * args.adjustment_interval
            )
        else:
            d_thpt.append(0)
            safe_d_thpt.append(0)

    # Warm-up samples are placed at negative times, ending one interval before 0.
    warmup_time = warmup_rr = warmup_isl = warmup_osl = None
    if warmup_metrics:
        interval = args.adjustment_interval
        n = len(warmup_metrics)
        warmup_time = [(i - n) * interval for i in range(n)]
        warmup_rr = [m["request_count"] for m in warmup_metrics]
        warmup_isl = [m["avg_isl"] for m in warmup_metrics]
        warmup_osl = [m["avg_osl"] for m in warmup_metrics]

    create_dryrun_plot(
        time=time_series,
        rr=rr,
        est_rr=est_rr,
        isl=isl,
        est_isl=est_isl,
        osl=osl,
        est_osl=est_osl,
        num_p=num_p,
        p_thpt=p_thpt,
        safe_p_thpt=safe_p_thpt,
        num_d=num_d,
        d_thpt=d_thpt,
        safe_d_thpt=safe_d_thpt,
        output_path=args.output_plot,
        warmup_time=warmup_time,
        warmup_rr=warmup_rr,
        warmup_isl=warmup_isl,
        warmup_osl=warmup_osl,
    )
...@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
choices=["vllm", "sglang", "trtllm", "mocker"], choices=["vllm", "sglang", "trtllm", "mocker"],
help="Backend type", help="Backend type",
) )
parser.add_argument(
"--mode",
default=SLAPlannerDefaults.mode,
choices=["disagg", "prefill", "decode"],
help="Planner mode: disagg (prefill+decode), prefill-only, or decode-only",
)
parser.add_argument( parser.add_argument(
"--no-operation", "--no-operation",
action="store_true", action="store_true",
...@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
"--max-gpu-budget", "--max-gpu-budget",
type=int, type=int,
default=SLAPlannerDefaults.max_gpu_budget, default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget", help="Maximum GPU budget (-1 for no budget enforcement)",
) )
parser.add_argument( parser.add_argument(
"--min-endpoint", "--min-endpoint",
......
...@@ -6,7 +6,7 @@ import asyncio ...@@ -6,7 +6,7 @@ import asyncio
import logging import logging
import math import math
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from prometheus_client import Gauge, start_http_server from prometheus_client import Gauge, start_http_server
...@@ -119,15 +119,110 @@ class PlannerPrometheusMetrics: ...@@ -119,15 +119,110 @@ class PlannerPrometheusMetrics:
self.gpu_hours = Gauge(f"{prefix}:gpu_hours", "Cumulative GPU hours used") self.gpu_hours = Gauge(f"{prefix}:gpu_hours", "Cumulative GPU hours used")
@dataclass
class PlannerSharedState:
    """Mutable state that one or more planner instances read and update.

    A planner that is not handed an instance creates its own; the disagg
    dry run passes one instance to both the prefill and the decode planner.
    """

    # Latest observed metrics; each instance starts with a fresh Metrics().
    last_metrics: Metrics = field(default_factory=Metrics)
    # Prefill / decode worker instance ids from the last observe_metrics() call.
    p_endpoints: list = field(default_factory=list)
    d_endpoints: list = field(default_factory=list)
    # Running total of GPU hours, accumulated once per observed interval.
    cumulative_gpu_hours: float = 0.0
    # time.time() timestamp of the start of the most recent adjustment interval.
    last_adjustment_time: float = 0.0
def _apply_global_gpu_budget(
next_num_p: int, next_num_d: int, args: argparse.Namespace
) -> tuple[int, int]:
"""Apply GPU budget constraint to both prefill and decode replicas.
When total GPUs required (num_p * prefill_gpus + num_d * decode_gpus) exceeds the
budget, scale down both proportionally using scale = budget / total_required. Prefill
replicas are clamped to [min_endpoint, max_prefill] where max_prefill reserves enough
GPUs for min_endpoint decode replicas. Remaining budget is then allocated to decode.
Returns (0, 0) if budget cannot satisfy min_endpoint for both components.
"""
if args.max_gpu_budget < 0:
return next_num_p, next_num_d
total_gpu_required = (
next_num_p * args.prefill_engine_num_gpu
+ next_num_d * args.decode_engine_num_gpu
)
if total_gpu_required <= args.max_gpu_budget:
return next_num_p, next_num_d
min_required = (
args.min_endpoint * args.prefill_engine_num_gpu
+ args.min_endpoint * args.decode_engine_num_gpu
)
if args.max_gpu_budget < min_required:
logger.warning(
f"max_gpu_budget ({args.max_gpu_budget}) is below the minimum required "
f"for min_endpoint ({min_required}); enforcing zero replicas"
)
return 0, 0
scale = args.max_gpu_budget / total_gpu_required
max_prefill = math.floor(
(args.max_gpu_budget - args.min_endpoint * args.decode_engine_num_gpu)
/ args.prefill_engine_num_gpu
)
next_num_p = max(
args.min_endpoint, min(max_prefill, math.floor(next_num_p * scale))
)
remaining = args.max_gpu_budget - next_num_p * args.prefill_engine_num_gpu
next_num_d = max(
args.min_endpoint, math.floor(remaining / args.decode_engine_num_gpu)
)
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({args.max_gpu_budget}), "
f"scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
)
return next_num_p, next_num_d
def _apply_component_gpu_budget(
desired_replicas: int, engine_num_gpu: int, args: argparse.Namespace
) -> int:
"""Apply GPU budget constraint to a single component (prefill-only or decode-only).
When total GPUs required (replicas * gpus_per_replica) exceeds the budget, scale down
using scale = budget / total_required, floored and clamped to at least min_endpoint.
Returns 0 if budget cannot satisfy min_endpoint replicas.
"""
if args.max_gpu_budget < 0:
return desired_replicas
total_gpu_required = desired_replicas * engine_num_gpu
if total_gpu_required <= args.max_gpu_budget:
return desired_replicas
min_required = args.min_endpoint * engine_num_gpu
if args.max_gpu_budget < min_required:
logger.warning(
f"max_gpu_budget ({args.max_gpu_budget}) is below the minimum required "
f"for min_endpoint ({min_required}); enforcing zero replicas"
)
return 0
scale = args.max_gpu_budget / total_gpu_required
next_num = max(args.min_endpoint, math.floor(desired_replicas * scale))
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({args.max_gpu_budget}), "
f"scaling down to {next_num} replicas"
)
return next_num
class BasePlanner:
component_type: SubComponentType
def __init__( def __init__(
self, self,
runtime: Optional[DistributedRuntime], runtime: Optional[DistributedRuntime],
args: argparse.Namespace, args: argparse.Namespace,
dryrun: bool = False, dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_api_client: Optional[PrometheusAPIClient] = None,
connector=None,
start_prometheus_server: bool = True,
): ):
self.args = args self.args = args
self.dryrun = dryrun self.dryrun = dryrun
self.shared_state = shared_state or PlannerSharedState()
# Rely on getting model name from connector # Rely on getting model name from connector
self.model_name: Optional[str] = None self.model_name: Optional[str] = None
...@@ -137,7 +232,9 @@ class Planner: ...@@ -137,7 +232,9 @@ class Planner:
self.namespace = args.namespace self.namespace = args.namespace
if not args.no_operation: if not args.no_operation:
if args.environment == "kubernetes": if connector is not None:
self.connector = connector
elif args.environment == "kubernetes":
self.connector = KubernetesConnector( self.connector = KubernetesConnector(
self.namespace, self.model_name self.namespace, self.model_name
) )
...@@ -150,7 +247,7 @@ class Planner: ...@@ -150,7 +247,7 @@ class Planner:
else: else:
raise ValueError(f"Invalid environment: {args.environment}") raise ValueError(f"Invalid environment: {args.environment}")
self.prometheus_api_client = PrometheusAPIClient( self.prometheus_api_client = prometheus_api_client or PrometheusAPIClient(
args.metric_pulling_prometheus_endpoint, args.metric_pulling_prometheus_endpoint,
args.namespace, args.namespace,
) )
...@@ -231,22 +328,16 @@ class Planner: ...@@ -231,22 +328,16 @@ class Planner:
if not self.dryrun: if not self.dryrun:
self.prefill_client = None self.prefill_client = None
self.workers_client = None self.workers_client = None
self.p_endpoints = [] # type: ignore
self.d_endpoints = [] # type: ignore
self.last_adjustment_time = time.time()
self.last_metrics = Metrics()
self.prometheus_port = args.metric_reporting_prometheus_port self.prometheus_port = args.metric_reporting_prometheus_port
# Initialize Prometheus metrics if prometheus_metrics is None:
self.prometheus_metrics = PlannerPrometheusMetrics() self.prometheus_metrics = PlannerPrometheusMetrics()
else:
# Track cumulative GPU hours self.prometheus_metrics = prometheus_metrics
self.cumulative_gpu_hours = 0.0
# Start Prometheus HTTP server if port is specified # Start Prometheus HTTP server if port is specified
if self.prometheus_port != 0: if start_prometheus_server and self.prometheus_port != 0:
try: try:
start_http_server(self.prometheus_port) start_http_server(self.prometheus_port)
logger.info( logger.info(
...@@ -254,6 +345,9 @@ class Planner: ...@@ -254,6 +345,9 @@ class Planner:
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}") logger.error(f"Failed to start Prometheus metrics server: {e}")
else:
self.prometheus_port = 0
self.prometheus_metrics = prometheus_metrics
self.p_correction_factor = 1.0 self.p_correction_factor = 1.0
self.d_correction_factor = 1.0 self.d_correction_factor = 1.0
...@@ -262,6 +356,14 @@ class Planner: ...@@ -262,6 +356,14 @@ class Planner:
else: else:
self.no_correction = args.no_correction self.no_correction = args.no_correction
@property
def last_metrics(self) -> Metrics:
return self.shared_state.last_metrics
@last_metrics.setter
def last_metrics(self, value: Metrics) -> None:
self.shared_state.last_metrics = value
async def _async_init(self): async def _async_init(self):
"""Async initialization for components that need it""" """Async initialization for components that need it"""
if ( if (
...@@ -271,80 +373,94 @@ class Planner: ...@@ -271,80 +373,94 @@ class Planner:
): ):
await self.connector._async_init() await self.connector._async_init()
async def get_workers_info(self): async def _get_model_name(self, require_prefill: bool, require_decode: bool) -> str:
model_name = self.connector.get_model_name(
require_prefill=require_prefill, require_decode=require_decode
)
if asyncio.iscoroutine(model_name):
model_name = await model_name
return model_name
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
client = (
await self.runtime.namespace(self.namespace)
.component(component_name)
.endpoint(endpoint_name)
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
return client
async def get_workers_info(
self, require_prefill: bool = True, require_decode: bool = True
):
if self.runtime is None: if self.runtime is None:
raise RuntimeError("Runtime is not initialized") raise RuntimeError("Runtime is not initialized")
try: p_endpoints = []
if self.prefill_client is None: d_endpoints = []
self.prefill_client = ( worker_names = WORKER_COMPONENT_NAMES[self.args.backend]
await self.runtime.namespace(self.namespace)
.component( if require_prefill:
WORKER_COMPONENT_NAMES[ try:
self.args.backend if self.prefill_client is None:
].prefill_worker_component_name self.prefill_client = await self._get_or_create_client(
) worker_names.prefill_worker_component_name,
.endpoint( worker_names.prefill_worker_endpoint,
WORKER_COMPONENT_NAMES[
self.args.backend
].prefill_worker_endpoint
) )
.client() p_endpoints = self.prefill_client.instance_ids() # type: ignore
except Exception:
p_endpoints = []
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
) )
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1) if require_decode:
# TODO: use etcd events instead of pulling instance_ids try:
p_endpoints = self.prefill_client.instance_ids() # type: ignore if self.workers_client is None:
except Exception: self.workers_client = await self._get_or_create_client(
p_endpoints = [] worker_names.decode_worker_component_name,
logger.warning( worker_names.decode_worker_endpoint,
"No prefill workers found, aggregated mode is not supported yet"
)
try:
if self.workers_client is None:
self.workers_client = (
await self.runtime.namespace(self.namespace)
.component(
WORKER_COMPONENT_NAMES[
self.args.backend
].decode_worker_component_name
)
.endpoint(
WORKER_COMPONENT_NAMES[self.args.backend].decode_worker_endpoint
) )
.client() d_endpoints = self.workers_client.instance_ids() # type: ignore
) except Exception as e:
# TODO: remove this sleep after rust client() is blocking until watching state raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
await asyncio.sleep(0.1)
# TODO: use etcd events instead of pulling instance_ids
d_endpoints = self.workers_client.instance_ids() # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints return p_endpoints, d_endpoints
async def observe_metrics(self): async def observe_metrics(
self.p_endpoints, self.d_endpoints = await self.get_workers_info() self, require_prefill: bool = True, require_decode: bool = True
):
p_endpoints, d_endpoints = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.p_endpoints = p_endpoints
self.shared_state.d_endpoints = d_endpoints
logger.debug( logger.debug(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}" f"Number of prefill workers: {len(p_endpoints)}, number of decode workers: {len(d_endpoints)}"
) )
# Update Prometheus metrics if server is running # Update Prometheus metrics if server is running
if self.prometheus_port != 0: if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.num_p_workers.set(len(self.p_endpoints)) self.prometheus_metrics.num_p_workers.set(len(p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(self.d_endpoints)) self.prometheus_metrics.num_d_workers.set(len(d_endpoints))
# Calculate and accumulate GPU hours for this interval # Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours # TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours = ( interval_gpu_hours = (
( (
len(self.p_endpoints) * self.args.prefill_engine_num_gpu len(p_endpoints) * self.args.prefill_engine_num_gpu
+ len(self.d_endpoints) * self.args.decode_engine_num_gpu + len(d_endpoints) * self.args.decode_engine_num_gpu
) )
* self.args.adjustment_interval * self.args.adjustment_interval
/ 3600 / 3600
) )
self.cumulative_gpu_hours += interval_gpu_hours self.shared_state.cumulative_gpu_hours += interval_gpu_hours
self.prometheus_metrics.gpu_hours.set(self.cumulative_gpu_hours) self.prometheus_metrics.gpu_hours.set(
self.shared_state.cumulative_gpu_hours
)
# Prometheus returns seconds, convert to milliseconds # Prometheus returns seconds, convert to milliseconds
self.last_metrics.ttft = ( self.last_metrics.ttft = (
...@@ -392,7 +508,7 @@ class Planner: ...@@ -392,7 +508,7 @@ class Planner:
) )
# Update observed metrics in Prometheus # Update observed metrics in Prometheus
if self.prometheus_port != 0: if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft) self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft)
self.prometheus_metrics.observed_itl.set(self.last_metrics.itl) self.prometheus_metrics.observed_itl.set(self.last_metrics.itl)
self.prometheus_metrics.observed_request_rate.set( self.prometheus_metrics.observed_request_rate.set(
...@@ -404,9 +520,12 @@ class Planner: ...@@ -404,9 +520,12 @@ class Planner:
self.prometheus_metrics.observed_isl.set(self.last_metrics.isl) self.prometheus_metrics.observed_isl.set(self.last_metrics.isl)
self.prometheus_metrics.observed_osl.set(self.last_metrics.osl) self.prometheus_metrics.observed_osl.set(self.last_metrics.osl)
self.num_req_predictor.add_data_point(self.last_metrics.num_req) self.update_predictors_from_metrics(self.last_metrics)
self.isl_predictor.add_data_point(self.last_metrics.isl)
self.osl_predictor.add_data_point(self.last_metrics.osl) def update_predictors_from_metrics(self, metrics: Metrics) -> None:
self.num_req_predictor.add_data_point(metrics.num_req)
self.isl_predictor.add_data_point(metrics.isl)
self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self): def predict_load(self):
try: try:
...@@ -427,43 +546,193 @@ class Planner: ...@@ -427,43 +546,193 @@ class Planner:
self.isl_predictor.add_data_point(isl_avg) self.isl_predictor.add_data_point(isl_avg)
self.osl_predictor.add_data_point(osl_avg) self.osl_predictor.add_data_point(osl_avg)
def plan_adjustment(self) -> Optional[int]:
    """Compute the desired replica count for the next interval.

    Returns:
        The value of ``_compute_replica_requirements`` for the predicted load,
        or ``None`` when the adjustment is skipped: invalid metrics, a failed
        or declined correction-factor update, an unavailable prediction, or an
        error while computing the replica count.
    """
    # Skip adjustment if no traffic
    if not self.last_metrics.is_valid():
        logger.info(
            "Metrics contain None or NaN values (no active requests), skipping adjustment"
        )
        return None

    if not self.no_correction:
        try:
            correction_updated = self._update_correction_factor()
        except Exception as e:
            logger.error(f"Failed to correct prediction factors: {e}")
            return None
        if not correction_updated:
            return None

    next_num_req, next_isl, next_osl = self.predict_load()
    if any(value is None for value in (next_num_req, next_isl, next_osl)):
        return None

    # Update predicted load metrics in Prometheus
    metrics_enabled = self.prometheus_port != 0 and self.prometheus_metrics is not None
    if metrics_enabled:
        self.prometheus_metrics.predicted_request_rate.set(
            next_num_req / self.args.adjustment_interval
        )
        self.prometheus_metrics.predicted_isl.set(next_isl)
        self.prometheus_metrics.predicted_osl.set(next_osl)

    try:
        desired = self._compute_replica_requirements(next_num_req, next_isl, next_osl)
    except Exception as e:
        logger.error(f"Failed to compute number of replicas: {e}")
        return None
    return desired
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
    """Publish the planned replica count; must be overridden by subclasses."""
    raise NotImplementedError
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> int:
raise NotImplementedError
def _update_correction_factor(self) -> bool:
raise NotImplementedError
def _component_name(self) -> str:
    """Return the component name matching this planner's component type."""
    is_prefill = self.component_type == SubComponentType.PREFILL
    return self.prefill_component_name if is_prefill else self.decode_component_name
def _engine_num_gpu(self) -> int:
    """Return the configured GPUs per engine for this planner's component type."""
    is_prefill = self.component_type == SubComponentType.PREFILL
    return (
        self.args.prefill_engine_num_gpu
        if is_prefill
        else self.args.decode_engine_num_gpu
    )
def apply_component_budget(self, desired_replicas: int) -> int:
    """Clamp ``desired_replicas`` to the GPU budget for this planner's component."""
    gpus_per_engine = self._engine_num_gpu()
    return _apply_component_gpu_budget(desired_replicas, gpus_per_engine, self.args)
async def _apply_scaling(self, desired_replicas: int) -> None:
    """Ask the connector to scale this planner's component.

    Does nothing when ``args.no_operation`` is set. Otherwise a single
    TargetReplica for this component is submitted with ``blocking=False``.
    """
    if self.args.no_operation:
        return
    target = TargetReplica(
        sub_component_type=self.component_type,
        component_name=self._component_name(),
        desired_replicas=desired_replicas,
    )
    await self.connector.set_component_replicas([target], blocking=False)
async def run(self):
    """Main loop for the planner

    Unless ``args.no_operation`` is set, the deployment is first validated for
    this planner's component type, awaited until ready, and the model name is
    read from the connector (stored lower-cased). The method then loops
    forever: once per ``args.adjustment_interval`` it observes metrics, plans
    the replica count, applies the component GPU budget, updates the
    predicted-replicas metric and applies the scaling.
    """
    # Computed before the no_operation check: the adjustment loop below needs
    # these flags even when deployment validation is skipped, so they must not
    # be bound only inside the conditional block.
    require_prefill = self.component_type == SubComponentType.PREFILL
    require_decode = self.component_type == SubComponentType.DECODE

    if not self.args.no_operation:
        logger.info("Validating deployment...")
        await self.connector.validate_deployment(
            prefill_component_name=(
                self.prefill_component_name if require_prefill else None
            ),
            decode_component_name=(
                self.decode_component_name if require_decode else None
            ),
            require_prefill=require_prefill,
            require_decode=require_decode,
        )
        logger.info("Successfully validated the deployment")

        await self.connector.wait_for_deployment_ready()

        model_name = await self._get_model_name(
            require_prefill=require_prefill, require_decode=require_decode
        )
        logger.info(f"Detected model name from deployment: {model_name}")
        self.model_name = (
            model_name.lower()
        )  # normalize model name to lowercase (MDC)

    self.shared_state.last_adjustment_time = time.time()

    while True:
        current_time = time.time()

        if (
            current_time - self.shared_state.last_adjustment_time
            >= self.args.adjustment_interval
        ):
            self.shared_state.last_adjustment_time = time.time()
            logger.info("New adjustment interval started!")

            await self.observe_metrics(
                require_prefill=require_prefill, require_decode=require_decode
            )
            desired_replicas = self.plan_adjustment()
            # None means the adjustment was skipped for this interval.
            if desired_replicas is not None:
                desired_replicas = self.apply_component_budget(desired_replicas)
                self.update_predicted_replicas_metric(desired_replicas)
                await self._apply_scaling(desired_replicas)

        # sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
        await asyncio.sleep(self.args.adjustment_interval / 10)
class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL
def _update_correction_factor(self) -> bool:
    """Set the prefill correction factor to observed TTFT / interpolated TTFT.

    The expected TTFT is interpolated at the last observed ISL. The factor is
    logged and, when Prometheus reporting is active, exported. Always returns
    True.
    """
    observed = self.last_metrics
    expect_ttft = self.prefill_interpolator.interpolate_ttft(observed.isl)
    self.p_correction_factor = observed.ttft / expect_ttft
    logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")

    metrics_enabled = self.prometheus_port != 0 and self.prometheus_metrics is not None
    if metrics_enabled:
        self.prometheus_metrics.p_correction_factor.set(self.p_correction_factor)
    return True
def _compute_replica_requirements( def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float self, next_num_req: float, next_isl: float, next_osl: float
) -> tuple[int, int]: ) -> int:
"""Compute the number of prefill and decode replicas needed based on predicted load.
Args:
next_num_req: Predicted number of requests
next_isl: Predicted input sequence length
next_osl: Predicted output sequence length
Returns:
tuple[int, int]: Number of prefill and decode replicas needed
"""
# compute how many replicas are needed for prefill
# here we assume the prefill bias is purely due to request queueing
# and we increase the number of prefill replicas linearly to account for the queueing delay
pred_prefill_throughput = ( pred_prefill_throughput = (
next_num_req next_num_req
* next_isl * next_isl
/ self.args.adjustment_interval / self.args.adjustment_interval
* min(1, self.p_correction_factor) * min(1, self.p_correction_factor)
) )
p_thpt_per_gpu = self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
next_num_p = math.ceil( next_num_p = math.ceil(
pred_prefill_throughput pred_prefill_throughput / p_thpt_per_gpu / self.args.prefill_engine_num_gpu
/ self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl)
/ self.args.prefill_engine_num_gpu
) )
next_num_p = max(next_num_p, self.args.min_endpoint)
logger.info( logger.info(
f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / " f"Prefill calculation: {pred_prefill_throughput:.2f}(p_thpt) / "
f"{self.prefill_interpolator.interpolate_thpt_per_gpu(next_isl) * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = " f"{p_thpt_per_gpu * self.args.prefill_engine_num_gpu:.2f}(p_engine_cap) = "
f"{next_num_p}(num_p)" f"{next_num_p}(num_p)"
) )
return next_num_p
def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
    """Publish the predicted prefill replica count to the Prometheus gauge.

    Nothing is published when ``self.prometheus_port`` is 0 or when no
    metrics object is attached.

    Args:
        desired_replicas: Replica count to record in ``predicted_num_p``.
    """
    if self.prometheus_port != 0 and self.prometheus_metrics is not None:
        self.prometheus_metrics.predicted_num_p.set(desired_replicas)
class DecodePlanner(BasePlanner):
    """Planner that sizes the decode side of the deployment."""

    component_type = SubComponentType.DECODE

    def _update_correction_factor(self) -> bool:
        """Refresh ``d_correction_factor`` from the last observed ITL.

        The factor is the observed ITL divided by the ITL the decode
        interpolator expects for the observed per-worker concurrency and
        context length.

        Returns:
            Always ``True``; when no decode workers are known the factor is
            left untouched and only a warning is logged.
        """
        if not self.shared_state.d_endpoints:
            logger.warning(
                "No decode workers found for correction factor, skipping correction update"
            )
            return True
        expect_itl = self.decode_interpolator.interpolate_itl(
            concurrency=self.last_metrics.num_req  # type: ignore
            / len(self.shared_state.d_endpoints)
            * self.last_metrics.request_duration  # type: ignore
            / self.args.adjustment_interval,
            context_length=self.last_metrics.isl + self.last_metrics.osl / 2,  # type: ignore
        )
        self.d_correction_factor = self.last_metrics.itl / expect_itl
        logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
        if self.prometheus_port != 0 and self.prometheus_metrics is not None:
            self.prometheus_metrics.d_correction_factor.set(self.d_correction_factor)
        return True

    def _compute_replica_requirements(
        self, next_num_req: float, next_isl: float, next_osl: float
    ) -> int:
        """Compute the number of decode replicas needed for the predicted load.

        Args:
            next_num_req: Predicted number of requests.
            next_isl: Predicted input sequence length.
            next_osl: Predicted output sequence length.

        Returns:
            The number of decode replicas, never lower than
            ``self.args.min_endpoint``.
        """
        # 1. Apply d_correction_factor to the ITL SLA.
        # Prevent divide by zero when d_correction_factor is 0 (no metrics yet).
        if self.d_correction_factor <= 0:
            logger.warning(
                f"d_correction_factor is {self.d_correction_factor}, using default value of 1.0"
            )
            corrected_itl = self.args.itl
        else:
            corrected_itl = self.args.itl / self.d_correction_factor
        # 2. Reversely find the best throughput/gpu that can achieve
        # corrected_itl under the predicted context length.
        (
            pred_decode_thpt_per_gpu,
            _,
            _,
        ) = self.decode_interpolator.find_best_throughput_per_gpu(
            itl=corrected_itl, context_length=next_isl + next_osl / 2
        )
        # 3. Compute the number of decode replicas needed.
        pred_decode_throughput = next_num_req * next_osl / self.args.adjustment_interval
        next_num_d = math.ceil(
            pred_decode_throughput
            / pred_decode_thpt_per_gpu
            / self.args.decode_engine_num_gpu
        )
        next_num_d = max(next_num_d, self.args.min_endpoint)
        logger.info(
            f"Decode calculation: {pred_decode_throughput:.2f}(d_thpt) / "
            f"{pred_decode_thpt_per_gpu * self.args.decode_engine_num_gpu:.2f}(d_engine_cap) = "
            f"{next_num_d}(num_d)"
        )
        return next_num_d

    def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
        """Publish the predicted decode replica count to the Prometheus gauge.

        Nothing is published when ``self.prometheus_port`` is 0 or when no
        metrics object is attached.
        """
        if self.prometheus_port != 0 and self.prometheus_metrics is not None:
            self.prometheus_metrics.predicted_num_d.set(desired_replicas)
logger.info(
f"Predicted number of engine replicas: prefill={next_num_p}, decode={next_num_d}"
)
total_gpu_required = (
next_num_p * self.args.prefill_engine_num_gpu
+ next_num_d * self.args.decode_engine_num_gpu
)
if total_gpu_required > self.args.max_gpu_budget:
scale = self.args.max_gpu_budget / total_gpu_required
next_num_p = max(self.args.min_endpoint, round(next_num_p * scale))
next_num_d = max(
self.args.min_endpoint,
round(
(
self.args.max_gpu_budget
- next_num_p * self.args.prefill_engine_num_gpu
)
/ self.args.decode_engine_num_gpu
),
)
logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({self.args.max_gpu_budget}), scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
)
return next_num_p, next_num_d
async def make_adjustments(self):
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
if not self.no_correction:
try:
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
# first correct the prediction correction factor
# for TTFT, we expect the correction factor to be << 1 due to queuing delay
expect_ttft = self.prefill_interpolator.interpolate_ttft(
self.last_metrics.isl
)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for ITL, we expect the correction factor to be close to 1
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
# Update correction factor metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.p_correction_factor.set(
self.p_correction_factor
)
self.prometheus_metrics.d_correction_factor.set(
self.d_correction_factor
)
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is not None and next_isl is not None and next_osl is not None:
# Update predicted load metrics in Prometheus
if self.prometheus_port != 0:
self.prometheus_metrics.predicted_request_rate.set(
next_num_req / self.args.adjustment_interval
)
self.prometheus_metrics.predicted_isl.set(next_isl)
self.prometheus_metrics.predicted_osl.set(next_osl)
try:
next_num_p, next_num_d = self._compute_replica_requirements(
next_num_req, next_isl, next_osl
)
class DisaggPlanner:
    """Drive a prefill planner and a decode planner as one disaggregated planner.

    Both sub-planners are built with the same ``PlannerSharedState`` and the
    same ``PlannerPrometheusMetrics`` object. The decode planner is handed the
    prefill planner's Prometheus API client and connector, and only the
    prefill planner is asked to start the Prometheus server.
    """

    def __init__(
        self, runtime: Optional[DistributedRuntime], args: argparse.Namespace
    ) -> None:
        """Build the two sub-planners around one shared state object.

        Args:
            runtime: Runtime handle forwarded to both sub-planners.
            args: Parsed planner arguments, forwarded to both sub-planners.
        """
        self.args = args
        self.shared_state = PlannerSharedState()
        prometheus_metrics = PlannerPrometheusMetrics()

        self.prefill_planner = PrefillPlanner(
            runtime,
            args,
            shared_state=self.shared_state,
            prometheus_metrics=prometheus_metrics,
            start_prometheus_server=True,
        )
        self.decode_planner = DecodePlanner(
            runtime,
            args,
            shared_state=self.shared_state,
            prometheus_metrics=prometheus_metrics,
            # getattr with a None default: the prefill planner may not carry
            # these attributes, in which case None is forwarded.
            prometheus_api_client=getattr(
                self.prefill_planner, "prometheus_api_client", None
            ),
            connector=getattr(self.prefill_planner, "connector", None),
            start_prometheus_server=False,
        )

    async def _async_init(self):
        """Run the asynchronous part of initialization."""
        # Prefill/Decode share the same connector instance in disagg mode.
        await self.prefill_planner._async_init()

    async def run(self):
        """Main loop: observe metrics, plan both sides, apply the GPU budget, scale."""
        # NOTE(review): the nesting of this start-up section was reconstructed
        # from a listing without indentation; confirm that model-name detection
        # belongs inside the no_operation guard.
        if not self.args.no_operation:
            logger.info("Validating deployment...")
            await self.prefill_planner.connector.validate_deployment(
                prefill_component_name=self.prefill_planner.prefill_component_name,
                decode_component_name=self.prefill_planner.decode_component_name,
                require_prefill=True,
                require_decode=True,
            )
            logger.info("Successfully validated the deployment")

            await self.prefill_planner.connector.wait_for_deployment_ready()

            model_name = await self.prefill_planner._get_model_name(
                require_prefill=True, require_decode=True
            )
            logger.info(f"Detected model name from deployment: {model_name}")
            # Both planners get the same lower-cased model name.
            model_name = model_name.lower()
            self.prefill_planner.model_name = model_name
            self.decode_planner.model_name = model_name

        self.shared_state.last_adjustment_time = time.time()

        while True:
            current_time = time.time()

            if (
                current_time - self.shared_state.last_adjustment_time
                >= self.args.adjustment_interval
            ):
                self.shared_state.last_adjustment_time = time.time()
                logger.info("New adjustment interval started!")

                # Metrics are observed once through the prefill planner; the
                # decode planner's predictors are then fed the same metrics.
                await self.prefill_planner.observe_metrics(
                    require_prefill=True, require_decode=True
                )
                self.decode_planner.update_predictors_from_metrics(
                    self.shared_state.last_metrics
                )

                next_num_p = self.prefill_planner.plan_adjustment()
                next_num_d = self.decode_planner.plan_adjustment()
                if next_num_p is None or next_num_d is None:
                    # Nothing to apply this interval. The timestamp was
                    # already refreshed, so the next pass falls through to
                    # the sleep below.
                    continue

                next_num_p, next_num_d = _apply_global_gpu_budget(
                    next_num_p, next_num_d, self.args
                )
                self.prefill_planner.update_predicted_replicas_metric(next_num_p)
                self.decode_planner.update_predicted_replicas_metric(next_num_d)

                if not self.args.no_operation:
                    target_replicas = [
                        TargetReplica(
                            sub_component_type=SubComponentType.PREFILL,
                            component_name=self.prefill_planner.prefill_component_name,
                            desired_replicas=next_num_p,
                        ),
                        TargetReplica(
                            sub_component_type=SubComponentType.DECODE,
                            component_name=self.prefill_planner.decode_component_name,
                            desired_replicas=next_num_d,
                        ),
                    ]
                    await self.prefill_planner.connector.set_component_replicas(
                        target_replicas, blocking=False
                    )

            # sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
            await asyncio.sleep(self.args.adjustment_interval / 10)
def dryrun_run(self):
"""Run planner in dry-run mode with dataset"""
warmup_metrics = None
if getattr(self.args, "load_predictor_warmup_trace", None):
warmup_metrics = extract_metrics_from_mooncake(
self.args.load_predictor_warmup_trace,
self.args.adjustment_interval,
)
metrics = extract_metrics_from_mooncake(
self.args.dataset, self.args.adjustment_interval
)
def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
"""safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
actual_ttft = self.prefill_interpolator.interpolate_ttft(isl)
if actual_ttft > ttft:
return 0
else:
return num_p * self.prefill_interpolator.interpolate_thpt_per_gpu(isl)
def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
"""safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
(
pred_decode_thpt_per_gpu,
actual_itl,
_,
) = self.decode_interpolator.find_best_throughput_per_gpu(
itl=itl, context_length=isl + osl / 2
)
if actual_itl > itl:
return 0
else:
return num_d * pred_decode_thpt_per_gpu
time = [0]
rr = [metrics[0]["request_count"]]
est_rr = [metrics[0]["request_count"]]
isl = [metrics[0]["avg_isl"]]
est_isl = [metrics[0]["avg_isl"]]
osl = [metrics[0]["avg_osl"]]
est_osl = [metrics[0]["avg_osl"]]
num_p = [self.args.start_num_p]
p_thpt = [metrics[0]["request_count"] * metrics[0]["avg_isl"]]
safe_p_thpt = [
compute_safe_p_thpt(
self.args.start_num_p, metrics[0]["avg_isl"], self.args.ttft
)
* self.args.adjustment_interval
]
num_d = [self.args.start_num_d]
d_thpt = [metrics[0]["request_count"] * metrics[0]["avg_osl"]]
safe_d_thpt = [
compute_safe_d_thpt(
self.args.start_num_d,
metrics[0]["avg_isl"],
metrics[0]["avg_osl"],
self.args.itl,
)
* self.args.adjustment_interval
]
self.dryrun_observe_metrics(
metrics[0]["request_count"], metrics[0]["avg_isl"], metrics[0]["avg_osl"]
)
for metric in metrics[1:]:
# update time
time.append(time[-1] + self.args.adjustment_interval)
# load prediction
_est_rr, _est_isl, _est_osl = self.predict_load()
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
# compute num_p and num_d
_num_p, _num_d = self._compute_replica_requirements(
_est_rr, _est_isl, _est_osl
)
num_p.append(_num_p)
num_d.append(_num_d)
# update load predictor
self.dryrun_observe_metrics(
metric["request_count"], metric["avg_isl"], metric["avg_osl"]
)
# fill in ground truth
rr.append(metric["request_count"])
isl.append(metric["avg_isl"])
osl.append(metric["avg_osl"])
p_thpt.append(rr[-1] * isl[-1])
d_thpt.append(rr[-1] * osl[-1])
safe_p_thpt.append(
compute_safe_p_thpt(num_p[-1], isl[-1], self.args.ttft)
* self.args.adjustment_interval
)
safe_d_thpt.append(
compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], self.args.itl)
* self.args.adjustment_interval
)
# plot the results
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
interval = self.args.adjustment_interval
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time,
rr=rr,
est_rr=est_rr,
isl=isl,
est_isl=est_isl,
osl=osl,
est_osl=est_osl,
num_p=num_p,
p_thpt=p_thpt,
safe_p_thpt=safe_p_thpt,
num_d=num_d,
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=self.args.output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
)
async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace):
    """Create the planner selected by ``args.mode`` and run it.

    Args:
        runtime: Runtime handle passed to the planner constructor.
        args: Parsed planner arguments. ``args.mode`` picks the planner:
            ``"disagg"`` (used when the attribute is absent), ``"prefill"``
            or ``"decode"``.

    Raises:
        ValueError: If ``args.mode`` is none of the three supported values.
    """
    mode = getattr(args, "mode", "disagg")
    if mode == "disagg":
        planner = DisaggPlanner(runtime, args)
    elif mode == "prefill":
        planner = PrefillPlanner(runtime, args)
    elif mode == "decode":
        planner = DecodePlanner(runtime, args)
    else:
        raise ValueError(f"Invalid planner mode: {mode}")
    await planner._async_init()
    await planner.run()
...@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector): ...@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector):
self, self,
prefill_component_name: Optional[str] = None, prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None, decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
): ):
"""Validate the deployment""" """Validate the deployment"""
pass pass
...@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector): ...@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment to be ready""" """Wait for the deployment to be ready"""
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
async def get_model_name(
    self, require_prefill: bool = True, require_decode: bool = True
) -> str:
    """Get the model name from the deployment.

    Args:
        require_prefill: Accepted for signature compatibility; ignored here.
        require_decode: Accepted for signature compatibility; ignored here.

    Returns:
        The value of ``self.model_name``.
    """
    # Explicitly discard the flags to mark them as intentionally unused.
    del require_prefill, require_decode
    return self.model_name
...@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime ...@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime
sys.modules["dynamo.runtime.logging"] = mock_runtime.logging sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking # Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402 from dynamo.planner.utils.planner_core import ( # noqa: E402
DecodePlanner,
Metrics,
PlannerSharedState,
PrefillPlanner,
_apply_global_gpu_budget,
)
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0] pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
class PlannerHarness:
    """Test helper that exposes a prefill/decode planner pair as one object.

    Attribute reads and writes are routed to the prefill planner, the decode
    planner or the shared state according to the name sets below, so tests can
    keep addressing a single ``planner`` object.
    """

    # Stored directly on the harness.
    _OWN_ATTRS = frozenset({"prefill_planner", "decode_planner", "shared_state"})
    # Read from the prefill planner; written to both planners.
    _SHARED_ATTRS = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_api_client",
            "args",
        }
    )
    # The readable sets also cover the component names; the writable sets do
    # not, so assigning a component name lands on the harness itself.
    _PREFILL_READ_ATTRS = frozenset(
        {"prefill_interpolator", "prefill_component_name", "p_correction_factor"}
    )
    _DECODE_READ_ATTRS = frozenset(
        {"decode_interpolator", "decode_component_name", "d_correction_factor"}
    )
    _PREFILL_WRITE_ATTRS = frozenset({"prefill_interpolator", "p_correction_factor"})
    _DECODE_WRITE_ATTRS = frozenset({"decode_interpolator", "d_correction_factor"})

    def __init__(self, prefill_planner, decode_planner, shared_state):
        self.prefill_planner = prefill_planner
        self.decode_planner = decode_planner
        self.shared_state = shared_state
        # Last replica targets computed by make_adjustments (list of dicts).
        self.last_target_replicas = []

    async def make_adjustments(self):
        """Run one planning step and record the resulting replica targets.

        Returns early, leaving ``last_target_replicas`` untouched, when the
        shared metrics are not valid or either planner yields ``None``.
        """
        if not self.shared_state.last_metrics.is_valid():
            return
        p_endpoints, d_endpoints = await self.prefill_planner.get_workers_info()
        self.shared_state.p_endpoints = p_endpoints
        self.shared_state.d_endpoints = d_endpoints
        next_num_p = self.prefill_planner.plan_adjustment()
        next_num_d = self.decode_planner.plan_adjustment()
        if next_num_p is None or next_num_d is None:
            return
        next_num_p, next_num_d = _apply_global_gpu_budget(
            next_num_p, next_num_d, self.prefill_planner.args
        )
        self.prefill_planner.update_predicted_replicas_metric(next_num_p)
        self.decode_planner.update_predicted_replicas_metric(next_num_d)
        target_replicas = [
            {
                "sub_component_type": "prefill",
                "component_name": self.prefill_planner.prefill_component_name,
                "desired_replicas": next_num_p,
            },
            {
                "sub_component_type": "decode",
                "component_name": self.prefill_planner.decode_component_name,
                "desired_replicas": next_num_d,
            },
        ]
        # Recorded even in no-operation mode, so callers can inspect the
        # result without looking at the connector.
        self.last_target_replicas = target_replicas
        if not self.prefill_planner.args.no_operation:
            await self.prefill_planner.connector.set_component_replicas(
                target_replicas, blocking=False
            )

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so attributes stored on the
        # harness itself take precedence over this routing.
        if name == "last_metrics":
            return self.shared_state.last_metrics
        if (
            name == "get_workers_info"
            or name in self._SHARED_ATTRS
            or name in self._PREFILL_READ_ATTRS
        ):
            return getattr(self.prefill_planner, name)
        if name in self._DECODE_READ_ATTRS:
            return getattr(self.decode_planner, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in self._OWN_ATTRS:
            super().__setattr__(name, value)
        elif name == "last_metrics":
            self.shared_state.last_metrics = value
        elif name == "get_workers_info" or name in self._SHARED_ATTRS:
            # Store locally to support patch.object lifecycle (set/del).
            object.__setattr__(self, name, value)
            setattr(self.prefill_planner, name, value)
            setattr(self.decode_planner, name, value)
        elif name in self._PREFILL_WRITE_ATTRS:
            setattr(self.prefill_planner, name, value)
        elif name in self._DECODE_WRITE_ATTRS:
            setattr(self.decode_planner, name, value)
        else:
            super().__setattr__(name, value)
def _replica_count(target_replicas, component_name, default=1):
    """Return the desired replica count recorded for ``component_name``.

    Args:
        target_replicas: Iterable of dicts, each read through the keys
            ``"component_name"`` and ``"desired_replicas"``.
        component_name: Component to look up; the first matching entry wins.
        default: Value returned when no entry matches, or when the matching
            entry has no ``"desired_replicas"`` key.
    """
    for replica in target_replicas:
        if replica.get("component_name") == component_name:
            return replica.get("desired_replicas", default)
    return default
@pytest.fixture @pytest.fixture
def planner(): def planner():
"""Set up test environment with mocked dependencies.""" """Set up test environment with mocked dependencies."""
...@@ -75,8 +196,10 @@ def planner(): ...@@ -75,8 +196,10 @@ def planner():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge: with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock() mock_gauge.return_value = Mock()
# Create planner instance shared_state = PlannerSharedState()
planner = Planner(mock_runtime, args) prefill_planner = PrefillPlanner(mock_runtime, args, shared_state=shared_state)
decode_planner = DecodePlanner(mock_runtime, args, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Mock the interpolators to return fixed values for testing # Mock the interpolators to return fixed values for testing
planner.prefill_interpolator = Mock() planner.prefill_interpolator = Mock()
...@@ -165,19 +288,18 @@ class TestReplicaCalculation: ...@@ -165,19 +288,18 @@ class TestReplicaCalculation:
# Extract the calculated values from the log calls or by checking the mock calls # Extract the calculated values from the log calls or by checking the mock calls
# Since we mocked the connector, we can check what replicas were requested # Since we mocked the connector, we can check what replicas were requested
if planner.connector.set_component_replicas.called: prefill_component = "VllmPrefillWorker"
call_args = planner.connector.set_component_replicas.call_args[0][0] calculated_prefill_replicas = _replica_count(
prefill_component = "VllmPrefillWorker" planner.last_target_replicas, prefill_component
calculated_prefill_replicas = call_args.get(prefill_component, 1) )
print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Expected prefill replicas: {expected_prefill_replicas}") print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
# Allow for small differences due to min_endpoint constraints # Allow for small differences due to min_endpoint constraints
assert ( assert (
max(expected_prefill_replicas, planner.args.min_endpoint) max(expected_prefill_replicas, planner.args.min_endpoint)
== calculated_prefill_replicas == calculated_prefill_replicas
) )
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -230,19 +352,18 @@ class TestReplicaCalculation: ...@@ -230,19 +352,18 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Check the results # Check the results
if planner.connector.set_component_replicas.called: decode_component = "VllmDecodeWorker"
call_args = planner.connector.set_component_replicas.call_args[0][0] calculated_decode_replicas = _replica_count(
decode_component = "VllmDecodeWorker" planner.last_target_replicas, decode_component
calculated_decode_replicas = call_args.get(decode_component, 1) )
print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Expected decode replicas: {expected_decode_replicas}") print(f"Calculated decode replicas: {calculated_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}")
# Allow for small differences due to min_endpoint constraints # Allow for small differences due to min_endpoint constraints
assert ( assert (
max(expected_decode_replicas, planner.args.min_endpoint) max(expected_decode_replicas, planner.args.min_endpoint)
== calculated_decode_replicas == calculated_decode_replicas
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_req,decode_thpt,expected_p,expected_d", "num_req,decode_thpt,expected_p,expected_d",
...@@ -304,20 +425,20 @@ class TestReplicaCalculation: ...@@ -304,20 +425,20 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify results # Verify results
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}") print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
assert ( assert (
prefill_replicas == expected_p prefill_replicas == expected_p
), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}" ), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}"
assert ( assert (
decode_replicas == expected_d decode_replicas == expected_d
), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}" ), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}"
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -359,24 +480,24 @@ class TestReplicaCalculation: ...@@ -359,24 +480,24 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify that total GPU usage doesn't exceed budget # Verify that total GPU usage doesn't exceed budget
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = ( total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu + decode_replicas * planner.args.decode_engine_num_gpu
) )
print( print(
f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}" f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
) )
assert ( assert (
total_gpus <= planner.args.max_gpu_budget total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage exceeds budget" ), "Total GPU usage exceeds budget"
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -417,20 +538,20 @@ class TestReplicaCalculation: ...@@ -417,20 +538,20 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify minimum constraints are respected # Verify minimum constraints are respected
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}") print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
assert ( assert (
prefill_replicas >= planner.args.min_endpoint prefill_replicas >= planner.args.min_endpoint
), "Prefill replicas below minimum" ), "Prefill replicas below minimum"
assert ( assert (
decode_replicas >= planner.args.min_endpoint decode_replicas >= planner.args.min_endpoint
), "Decode replicas below minimum" ), "Decode replicas below minimum"
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -482,17 +603,17 @@ class TestReplicaCalculation: ...@@ -482,17 +603,17 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify that correction factor was effectively clamped # Verify that correction factor was effectively clamped
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
prefill_replicas = call_args.get("VllmPrefillWorker", 1) )
print( print(
f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}" f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}"
) )
assert prefill_replicas == max( assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint expected_prefill_replicas, planner.args.min_endpoint
), "Prefill correction factor should be clamped to 1" ), "Prefill correction factor should be clamped to 1"
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -501,62 +622,59 @@ class TestReplicaCalculation: ...@@ -501,62 +622,59 @@ class TestReplicaCalculation:
"""Test handling of d_correction_factor <= 0.""" """Test handling of d_correction_factor <= 0."""
# Test both 0 and negative values # Test both 0 and negative values
for correction_factor in [0.0, -1.0]: for correction_factor in [0.0, -1.0]:
with patch.object(planner, "connector") as mock_connector: planner.p_correction_factor = 1.0
planner.p_correction_factor = 1.0 planner.d_correction_factor = correction_factor
planner.d_correction_factor = correction_factor
# Mock predictor outputs
# Mock predictor outputs planner.num_req_predictor.predict_next.return_value = 10
planner.num_req_predictor.predict_next.return_value = 10 planner.isl_predictor.predict_next.return_value = 3000
planner.isl_predictor.predict_next.return_value = 3000 planner.osl_predictor.predict_next.return_value = 150
planner.osl_predictor.predict_next.return_value = 150
# Mock interpolator outputs
# Mock interpolator outputs planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = 40000
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = ( planner.decode_interpolator.find_best_throughput_per_gpu.return_value = (
40000 10000,
) 0.01,
planner.decode_interpolator.find_best_throughput_per_gpu.return_value = ( 0.5,
10000, )
0.01,
0.5, # Set up metrics
) planner.last_metrics = Metrics(
num_req=10,
# Set up metrics isl=3000,
planner.last_metrics = Metrics( osl=150,
num_req=10, ttft=80.0,
isl=3000, itl=10.0,
osl=150, request_duration=100.0,
ttft=80.0, )
itl=10.0,
request_duration=100.0, # Mock workers info
) async def mock_get_workers_info():
return (["prefill1"], ["decode1"])
# Mock workers info
async def mock_get_workers_info(): planner.get_workers_info = mock_get_workers_info
return (["prefill1"], ["decode1"])
# Mock interpolation calls
planner.get_workers_info = mock_get_workers_info planner.prefill_interpolator.interpolate_ttft.return_value = 80.0
planner.decode_interpolator.interpolate_itl.return_value = 10.0
# Mock interpolation calls
planner.prefill_interpolator.interpolate_ttft.return_value = 80.0 # Run calculation
planner.decode_interpolator.interpolate_itl.return_value = 10.0 asyncio.run(planner.make_adjustments())
# Run calculation # Should handle gracefully without crashing
asyncio.run(planner.make_adjustments()) # The code should use args.itl directly instead of dividing by 0
decode_replicas = _replica_count(
# Should handle gracefully without crashing planner.last_target_replicas, "VllmDecodeWorker"
# The code should use args.itl directly instead of dividing by 0 )
if mock_connector.set_component_replicas.called:
call_args = mock_connector.set_component_replicas.call_args[0][0] print(
decode_replicas = call_args.get("VllmDecodeWorker", 1) f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}"
)
print(
f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}" # Should get a valid result (not crash)
) assert (
decode_replicas >= 1
# Should get a valid result (not crash) ), f"Should handle correction factor {correction_factor} gracefully"
assert (
decode_replicas >= 1
), f"Should handle correction factor {correction_factor} gracefully"
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -608,23 +726,23 @@ class TestReplicaCalculation: ...@@ -608,23 +726,23 @@ class TestReplicaCalculation:
# Run calculation # Run calculation
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print( print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})" f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})"
) )
# Verify calculations account for multiple GPUs per engine # Verify calculations account for multiple GPUs per engine
assert prefill_replicas == max( assert prefill_replicas == max(
expected_prefill_replicas, planner.args.min_endpoint expected_prefill_replicas, planner.args.min_endpoint
) )
assert decode_replicas == max( assert decode_replicas == max(
expected_decode_replicas, planner.args.min_endpoint expected_decode_replicas, planner.args.min_endpoint
) )
@pytest.mark.weekly @pytest.mark.weekly
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -668,31 +786,31 @@ class TestReplicaCalculation: ...@@ -668,31 +786,31 @@ class TestReplicaCalculation:
# Run calculation # Run calculation
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
# Verify total GPU usage doesn't exceed budget # Verify total GPU usage doesn't exceed budget
total_gpus = ( total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu + decode_replicas * planner.args.decode_engine_num_gpu
) )
print( print(
f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}" f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
) )
assert ( assert (
total_gpus <= planner.args.max_gpu_budget total_gpus <= planner.args.max_gpu_budget
), "Total GPU usage should not exceed budget" ), "Total GPU usage should not exceed budget"
assert ( assert (
prefill_replicas >= planner.args.min_endpoint prefill_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for prefill" ), "Should respect min_endpoint for prefill"
assert ( assert (
decode_replicas >= planner.args.min_endpoint decode_replicas >= planner.args.min_endpoint
), "Should respect min_endpoint for decode" ), "Should respect min_endpoint for decode"
# No need for unittest.main() with pytest! # No need for unittest.main() with pytest!
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import logging import logging
from dynamo.planner.utils.dryrun import run_sla_planner_dryrun
from dynamo.planner.utils.planner_argparse import create_sla_planner_parser from dynamo.planner.utils.planner_argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,5 +45,4 @@ if __name__ == "__main__": ...@@ -45,5 +45,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
planner = Planner(None, args, dryrun=True) run_sla_planner_dryrun(args)
planner.dryrun_run()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
)
# Module-level pytest markers: pytest applies every mark listed in
# ``pytestmark`` to each test collected from this file.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.planner,
]
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Patch ``Gauge`` in ``planner_core`` for the duration of each test.

    While the fixture is active, calling the patched ``Gauge`` returns a
    plain ``Mock``. The fixture is ``autouse``, so it wraps every test in
    this module without being requested explicitly.
    """
    gauge_target = "dynamo.planner.utils.planner_core.Gauge"
    # Extra keyword arguments to ``patch`` are forwarded to the replacement
    # mock, so this configures its return value in one step.
    with patch(gauge_target, return_value=Mock()):
        yield
def _build_args():
    """Return the planner argument namespace used by the tests in this file.

    The namespace selects the ``"vllm"`` backend in ``"disagg"`` mode with one
    GPU per prefill/decode engine, a 60 second adjustment interval, TTFT/ITL
    targets of 500.0 / 50.0, and ``no_operation`` / ``no_correction`` both
    enabled. ``profile_results_dir`` is resolved relative to this file.
    """
    profile_dir = os.path.join(
        os.path.dirname(__file__),
        "..",
        "profiling_results",
        "H200_TP1P_TP1D",
    )
    # Keyword order mirrors the attribute order the namespace is expected
    # to expose (it determines ``vars(args)`` ordering and the repr).
    return argparse.Namespace(
        adjustment_interval=60,
        prefill_engine_num_gpu=1,
        decode_engine_num_gpu=1,
        min_endpoint=1,
        max_gpu_budget=-1,
        ttft=500.0,
        itl=50.0,
        backend="vllm",
        no_operation=True,
        no_correction=True,
        metric_pulling_prometheus_endpoint="http://localhost:9090",
        metric_reporting_prometheus_port=0,
        load_predictor="constant",
        load_predictor_warmup_trace=None,
        profile_results_dir=profile_dir,
        environment="kubernetes",
        namespace="test-namespace",
        mode="disagg",
    )
def _build_prometheus_client(samples):
    """Build a ``Mock`` Prometheus client that replays ``samples`` in order.

    Args:
        samples: Sequence of dicts with the keys ``ttft_ms``, ``itl_ms``,
            ``num_req``, ``request_duration``, ``isl`` and ``osl``.

    Returns:
        A ``Mock`` whose ``get_avg_*`` methods each return one value per call,
        taken from successive samples. ``ttft_ms`` and ``itl_ms`` are divided
        by 1000 before being returned; the other fields are passed through.
    """
    # Query method name -> how to derive its value from a single sample.
    extractors = {
        "get_avg_time_to_first_token": lambda s: s["ttft_ms"] / 1000,
        "get_avg_inter_token_latency": lambda s: s["itl_ms"] / 1000,
        "get_avg_request_count": lambda s: s["num_req"],
        "get_avg_request_duration": lambda s: s["request_duration"],
        "get_avg_input_sequence_tokens": lambda s: s["isl"],
        "get_avg_output_sequence_tokens": lambda s: s["osl"],
    }

    client = Mock()
    for method_name, extract in extractors.items():
        # A list side_effect makes the mock return the next item on each call.
        getattr(client, method_name).side_effect = [extract(s) for s in samples]
    return client
def _build_planners(args, prometheus_client):
    """Create a prefill and a decode planner that share one state object.

    Both planners are constructed with ``None`` as their first argument and
    the same ``PlannerSharedState``. Each then gets ``prometheus_client`` as
    its ``prometheus_api_client``, ``"test-model"`` as its ``model_name``, and
    a stub ``get_workers_info`` coroutine function.

    Returns:
        ``(prefill_planner, decode_planner, shared_state)``.
    """

    async def mock_get_workers_info(require_prefill=True, require_decode=True):
        # One fixed worker per requested role, an empty list otherwise.
        prefill_workers = ["prefill-0"] if require_prefill else []
        decode_workers = ["decode-0"] if require_decode else []
        return (prefill_workers, decode_workers)

    shared_state = PlannerSharedState()
    prefill_planner = PrefillPlanner(None, args, shared_state=shared_state)
    decode_planner = DecodePlanner(None, args, shared_state=shared_state)

    for planner in (prefill_planner, decode_planner):
        planner.prometheus_api_client = prometheus_client
        planner.model_name = "test-model"
        planner.get_workers_info = mock_get_workers_info

    return prefill_planner, decode_planner, shared_state
def _expected_prefill(args, prefill_planner, sample):
    """Compute the prefill replica count the test expects for ``sample``.

    The predicted prefill throughput is ``num_req * isl`` divided by
    ``args.adjustment_interval``. It is divided by the per-GPU throughput the
    planner's prefill interpolator reports for ``isl`` and by
    ``args.prefill_engine_num_gpu``, then rounded up. The result is never
    below ``args.min_endpoint``.
    """
    token_rate = sample["num_req"] * sample["isl"] / args.adjustment_interval
    interpolator = prefill_planner.prefill_interpolator
    per_gpu_rate = interpolator.interpolate_thpt_per_gpu(sample["isl"])
    replicas = math.ceil(token_rate / per_gpu_rate / args.prefill_engine_num_gpu)
    if replicas < args.min_endpoint:
        return args.min_endpoint
    return replicas
def _expected_decode(args, decode_planner, sample):
    """Compute the decode replica count the test expects for ``sample``.

    The planner's decode interpolator is asked for the best per-GPU
    throughput at ``args.itl`` with a context length of ``isl + osl / 2``;
    only the first element of the returned 3-tuple is used. The predicted
    decode throughput (``num_req * osl`` divided by
    ``args.adjustment_interval``) is divided by that value and by
    ``args.decode_engine_num_gpu``, then rounded up. The result is never
    below ``args.min_endpoint``.
    """
    context_length = sample["isl"] + sample["osl"] / 2
    interpolator = decode_planner.decode_interpolator
    per_gpu_rate, _, _ = interpolator.find_best_throughput_per_gpu(
        itl=args.itl, context_length=context_length
    )
    token_rate = sample["num_req"] * sample["osl"] / args.adjustment_interval
    replicas = math.ceil(token_rate / per_gpu_rate / args.decode_engine_num_gpu)
    if replicas < args.min_endpoint:
        return args.min_endpoint
    return replicas
def _run_interval(prefill_planner, decode_planner, shared_state):
    """Run one planning step and return ``(next_num_p, next_num_d)``.

    Metrics are observed through the prefill planner only; the decode planner
    is then updated from ``shared_state.last_metrics`` (presumably filled in
    by ``observe_metrics`` -- confirm against planner_core). Finally each
    planner's ``plan_adjustment`` is called, prefill first.
    """
    observation = prefill_planner.observe_metrics(
        require_prefill=True, require_decode=True
    )
    asyncio.run(observation)

    # Read the shared metrics only after the observation has completed.
    latest_metrics = shared_state.last_metrics
    decode_planner.update_predictors_from_metrics(latest_metrics)

    return prefill_planner.plan_adjustment(), decode_planner.plan_adjustment()
def test_disagg_scale_up():
    """Replica targets grow when the request count jumps from 10 to 5000.

    Two intervals are planned with identical sequence lengths and latencies.
    Each interval's prefill/decode targets must equal the independently
    computed expectations, and the second interval's targets must be strictly
    larger than the first's.
    """
    args = _build_args()
    common = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    # Light load first, heavy load second.
    samples = [{"num_req": 10, **common}, {"num_req": 5000, **common}]

    prefill_planner, decode_planner, shared_state = _build_planners(
        args, _build_prometheus_client(samples)
    )
    observed = [
        _run_interval(prefill_planner, decode_planner, shared_state)
        for _ in samples
    ]

    for (num_p, num_d), sample in zip(observed, samples):
        assert num_p == _expected_prefill(args, prefill_planner, sample)
        assert num_d == _expected_decode(args, decode_planner, sample)

    (low_p, low_d), (high_p, high_d) = observed
    assert high_p > low_p
    assert high_d > low_d
def test_disagg_scale_down():
    """Replica targets shrink when the request count drops from 5000 to 10.

    Two intervals are planned with identical sequence lengths and latencies.
    Each interval's prefill/decode targets must equal the independently
    computed expectations, and the second interval's targets must be strictly
    smaller than the first's.
    """
    args = _build_args()
    common = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    # Heavy load first, light load second.
    samples = [{"num_req": 5000, **common}, {"num_req": 10, **common}]

    prefill_planner, decode_planner, shared_state = _build_planners(
        args, _build_prometheus_client(samples)
    )
    observed = [
        _run_interval(prefill_planner, decode_planner, shared_state)
        for _ in samples
    ]

    for (num_p, num_d), sample in zip(observed, samples):
        assert num_p == _expected_prefill(args, prefill_planner, sample)
        assert num_d == _expected_decode(args, decode_planner, sample)

    (high_p, high_d), (low_p, low_d) = observed
    assert low_p < high_p
    assert low_d < high_d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment