"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "aa7012eb6db69baab57c80ac596d088eb81e090f"
Unverified commit e2f1e04f, authored by Hongkuan Zhou and committed by GitHub
Browse files

refactor: separate planner into prefill/decode planner (#5622)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 77aadb72
...@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_min_points = 5 kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
mode = "disagg" # ["disagg", "prefill", "decode"]
class VllmComponentName: class VllmComponentName:
......
...@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector): ...@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector):
self, self,
prefill_component_name: Optional[str] = None, prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None, decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
): ):
""" """
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists. Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
...@@ -126,6 +128,7 @@ class KubernetesConnector(PlannerConnector): ...@@ -126,6 +128,7 @@ class KubernetesConnector(PlannerConnector):
errors = [] errors = []
if require_prefill:
try: try:
get_service_from_sub_component_type_or_name( get_service_from_sub_component_type_or_name(
deployment, deployment,
...@@ -135,6 +138,7 @@ class KubernetesConnector(PlannerConnector): ...@@ -135,6 +138,7 @@ class KubernetesConnector(PlannerConnector):
except PlannerError as e: except PlannerError as e:
errors.append(str(e)) errors.append(str(e))
if require_decode:
try: try:
get_service_from_sub_component_type_or_name( get_service_from_sub_component_type_or_name(
deployment, deployment,
...@@ -145,7 +149,11 @@ class KubernetesConnector(PlannerConnector): ...@@ -145,7 +149,11 @@ class KubernetesConnector(PlannerConnector):
errors.append(str(e)) errors.append(str(e))
try: try:
self.get_model_name(deployment) self.get_model_name(
deployment,
require_prefill=require_prefill,
require_decode=require_decode,
)
except PlannerError as e: except PlannerError as e:
errors.append(str(e)) errors.append(str(e))
...@@ -153,7 +161,12 @@ class KubernetesConnector(PlannerConnector): ...@@ -153,7 +161,12 @@ class KubernetesConnector(PlannerConnector):
if errors: if errors:
raise DeploymentValidationError(errors) raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str: def get_model_name(
self,
deployment: Optional[dict] = None,
require_prefill: bool = True,
require_decode: bool = True,
) -> str:
"""Get the model name from the deployment""" """Get the model name from the deployment"""
try: try:
if deployment is None: if deployment is None:
...@@ -163,15 +176,19 @@ class KubernetesConnector(PlannerConnector): ...@@ -163,15 +176,19 @@ class KubernetesConnector(PlannerConnector):
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing # TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate # and model name logic, should consolidate
prefill_model_name = None
decode_model_name = None
if require_prefill:
prefill_service = get_service_from_sub_component_type_or_name( prefill_service = get_service_from_sub_component_type_or_name(
deployment, deployment,
SubComponentType.PREFILL, SubComponentType.PREFILL,
) )
prefill_model_name = prefill_service.get_model_name()
if require_decode:
decode_service = get_service_from_sub_component_type_or_name( decode_service = get_service_from_sub_component_type_or_name(
deployment, deployment,
SubComponentType.DECODE, SubComponentType.DECODE,
) )
prefill_model_name = prefill_service.get_model_name()
decode_model_name = decode_service.get_model_name() decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None: if prefill_model_name is None and decode_model_name is None:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
def run_sla_planner_dryrun(args: argparse.Namespace) -> None:
    """Simulate the SLA planner over a recorded trace and plot the outcome.

    Per-interval metrics are extracted from ``args.dataset``. The first
    interval seeds every series; for each later interval the load is
    predicted first, replica counts are computed and budget-limited, and only
    then is the real sample shown to the planner(s). Nothing is returned: the
    collected series are passed to ``create_dryrun_plot`` together with
    ``args.output_plot``.

    Args:
        args: Parsed planner arguments. Read here: ``dataset``,
            ``adjustment_interval``, ``ttft``, ``itl``, ``output_plot``, the
            optional ``mode`` (defaults to ``"disagg"``), the optional
            ``load_predictor_warmup_trace``, and, depending on the mode,
            ``start_num_p`` / ``prefill_engine_num_gpu`` and
            ``start_num_d`` / ``decode_engine_num_gpu``.

    Raises:
        ValueError: If no metrics could be extracted from the dataset, or if
            the mode is not ``"disagg"``, ``"prefill"`` or ``"decode"``.
    """
    # Optional warm-up trace; it is only forwarded to the plot by this function.
    warmup_trace = getattr(args, "load_predictor_warmup_trace", None)
    warmup_metrics = (
        extract_metrics_from_mooncake(warmup_trace, args.adjustment_interval)
        if warmup_trace
        else None
    )

    metrics = extract_metrics_from_mooncake(args.dataset, args.adjustment_interval)
    if not metrics:
        raise ValueError("Empty metrics dataset: cannot run dryrun")

    # Instantiate only the planners the selected mode calls for.
    mode = getattr(args, "mode", "disagg")
    prefill_planner: Optional[PrefillPlanner] = None
    decode_planner: Optional[DecodePlanner] = None
    if mode not in ("disagg", "prefill", "decode"):
        raise ValueError(f"Invalid planner mode: {mode}")
    if mode == "disagg":
        # Both planners are handed the same shared-state object.
        shared_state = PlannerSharedState()
        prefill_planner = PrefillPlanner(
            None, args, dryrun=True, shared_state=shared_state
        )
        decode_planner = DecodePlanner(
            None, args, dryrun=True, shared_state=shared_state
        )
    elif mode == "prefill":
        prefill_planner = PrefillPlanner(None, args, dryrun=True)
    else:
        decode_planner = DecodePlanner(None, args, dryrun=True)

    def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
        """safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
        assert prefill_planner is not None
        interpolator = prefill_planner.prefill_interpolator
        if interpolator.interpolate_ttft(isl) > ttft:
            # The interpolated TTFT already exceeds the target.
            return 0
        return num_p * interpolator.interpolate_thpt_per_gpu(isl)

    def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
        """safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
        assert decode_planner is not None
        best = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
            itl=itl, context_length=isl + osl / 2
        )
        thpt_per_gpu, achieved_itl, _ = best
        if achieved_itl > itl:
            return 0
        return num_d * thpt_per_gpu

    # Seed every series with the first sample; the "estimated" series start
    # from the same observed values.
    first_sample = metrics[0]
    first_rr = first_sample["request_count"]
    first_isl = first_sample["avg_isl"]
    first_osl = first_sample["avg_osl"]

    time_series = [0]
    rr_series = [first_rr]
    est_rr = [first_rr]
    isl_series = [first_isl]
    est_isl = [first_isl]
    osl_series = [first_osl]
    est_osl = [first_osl]

    if prefill_planner is None:
        num_p = [0]
        p_thpt = [0]
        safe_p_thpt = [0]
    else:
        num_p = [args.start_num_p]
        p_thpt = [first_rr * first_isl]
        safe_p_thpt = [
            compute_safe_p_thpt(args.start_num_p, first_isl, args.ttft)
            * args.adjustment_interval
        ]
        prefill_planner.dryrun_observe_metrics(first_rr, first_isl, first_osl)

    if decode_planner is None:
        num_d = [0]
        d_thpt = [0]
        safe_d_thpt = [0]
    else:
        num_d = [args.start_num_d]
        d_thpt = [first_rr * first_osl]
        safe_d_thpt = [
            compute_safe_d_thpt(args.start_num_d, first_isl, first_osl, args.itl)
            * args.adjustment_interval
        ]
        decode_planner.dryrun_observe_metrics(first_rr, first_isl, first_osl)

    # Load predictions are taken from the prefill planner when it exists,
    # otherwise from the decode planner.
    predictor_planner = prefill_planner or decode_planner
    assert predictor_planner is not None

    for sample in metrics[1:]:
        # Advance the clock by one adjustment interval.
        time_series.append(time_series[-1] + args.adjustment_interval)

        # Predict before the planners get to see this sample.
        next_rr, next_isl, next_osl = predictor_planner.predict_load()
        est_rr.append(next_rr)
        est_isl.append(next_isl)
        est_osl.append(next_osl)

        # Replica requirements for the predicted load (0 for an absent side).
        planned_p = 0
        if prefill_planner is not None:
            planned_p = prefill_planner._compute_replica_requirements(
                next_rr, next_isl, next_osl
            )
        planned_d = 0
        if decode_planner is not None:
            planned_d = decode_planner._compute_replica_requirements(
                next_rr, next_isl, next_osl
            )

        # GPU budget: the global helper when both sides are planned, the
        # per-component helper for a single-sided run.
        if prefill_planner is None:
            planned_d = _apply_component_gpu_budget(
                planned_d, args.decode_engine_num_gpu, args
            )
        elif decode_planner is None:
            planned_p = _apply_component_gpu_budget(
                planned_p, args.prefill_engine_num_gpu, args
            )
        else:
            planned_p, planned_d = _apply_global_gpu_budget(
                planned_p, planned_d, args
            )
        num_p.append(planned_p)
        num_d.append(planned_d)

        # Now show the real sample to every active planner.
        seen_rr = sample["request_count"]
        seen_isl = sample["avg_isl"]
        seen_osl = sample["avg_osl"]
        if prefill_planner is not None:
            prefill_planner.dryrun_observe_metrics(seen_rr, seen_isl, seen_osl)
        if decode_planner is not None:
            decode_planner.dryrun_observe_metrics(seen_rr, seen_isl, seen_osl)

        # Ground-truth series for the plot.
        rr_series.append(seen_rr)
        isl_series.append(seen_isl)
        osl_series.append(seen_osl)
        p_thpt.append(0 if prefill_planner is None else seen_rr * seen_isl)
        d_thpt.append(0 if decode_planner is None else seen_rr * seen_osl)
        safe_p_thpt.append(
            0
            if prefill_planner is None
            else compute_safe_p_thpt(planned_p, seen_isl, args.ttft)
            * args.adjustment_interval
        )
        safe_d_thpt.append(
            0
            if decode_planner is None
            else compute_safe_d_thpt(planned_d, seen_isl, seen_osl, args.itl)
            * args.adjustment_interval
        )

    # Warm-up samples are placed on negative timestamps, ending one interval
    # before t=0.
    warmup_time = warmup_rr = warmup_isl = warmup_osl = None
    if warmup_metrics:
        step = args.adjustment_interval
        count = len(warmup_metrics)
        warmup_time = [offset * step for offset in range(-count, 0)]
        warmup_rr = [entry["request_count"] for entry in warmup_metrics]
        warmup_isl = [entry["avg_isl"] for entry in warmup_metrics]
        warmup_osl = [entry["avg_osl"] for entry in warmup_metrics]

    create_dryrun_plot(
        time=time_series,
        rr=rr_series,
        est_rr=est_rr,
        isl=isl_series,
        est_isl=est_isl,
        osl=osl_series,
        est_osl=est_osl,
        num_p=num_p,
        p_thpt=p_thpt,
        safe_p_thpt=safe_p_thpt,
        num_d=num_d,
        d_thpt=d_thpt,
        safe_d_thpt=safe_d_thpt,
        output_path=args.output_plot,
        warmup_time=warmup_time,
        warmup_rr=warmup_rr,
        warmup_isl=warmup_isl,
        warmup_osl=warmup_osl,
    )
...@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
choices=["vllm", "sglang", "trtllm", "mocker"], choices=["vllm", "sglang", "trtllm", "mocker"],
help="Backend type", help="Backend type",
) )
parser.add_argument(
"--mode",
default=SLAPlannerDefaults.mode,
choices=["disagg", "prefill", "decode"],
help="Planner mode: disagg (prefill+decode), prefill-only, or decode-only",
)
parser.add_argument( parser.add_argument(
"--no-operation", "--no-operation",
action="store_true", action="store_true",
...@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
"--max-gpu-budget", "--max-gpu-budget",
type=int, type=int,
default=SLAPlannerDefaults.max_gpu_budget, default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget", help="Maximum GPU budget (-1 for no budget enforcement)",
) )
parser.add_argument( parser.add_argument(
"--min-endpoint", "--min-endpoint",
......
...@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector): ...@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector):
self, self,
prefill_component_name: Optional[str] = None, prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None, decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
): ):
"""Validate the deployment""" """Validate the deployment"""
pass pass
...@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector): ...@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment to be ready""" """Wait for the deployment to be ready"""
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
async def get_model_name(self) -> str: async def get_model_name(
self, require_prefill: bool = True, require_decode: bool = True
) -> str:
"""Get the model name from the deployment""" """Get the model name from the deployment"""
del require_prefill, require_decode
return self.model_name return self.model_name
...@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime ...@@ -36,11 +36,132 @@ sys.modules["dynamo.runtime"] = mock_runtime
sys.modules["dynamo.runtime.logging"] = mock_runtime.logging sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking # Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402 from dynamo.planner.utils.planner_core import ( # noqa: E402
DecodePlanner,
Metrics,
PlannerSharedState,
PrefillPlanner,
_apply_global_gpu_budget,
)
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0] pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
class PlannerHarness:
    """Present a prefill/decode planner pair as one planner-like test object.

    Attribute reads that are not found on the harness itself are forwarded to
    the shared state, the prefill planner or the decode planner according to
    the routing tables below; attribute writes are fanned out the same way.
    Any other name behaves like a normal instance attribute.
    """

    # Reads answered by the prefill planner.
    _PREFILL_READS = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_api_client",
            "args",
            "prefill_interpolator",
            "prefill_component_name",
            "p_correction_factor",
        }
    )
    # Reads answered by the decode planner.
    _DECODE_READS = frozenset(
        {
            "decode_interpolator",
            "decode_component_name",
            "d_correction_factor",
        }
    )
    # Writes mirrored onto the harness and onto both planners.
    _SHARED_WRITES = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_api_client",
            "args",
            "get_workers_info",
        }
    )
    # Writes forwarded to exactly one planner.
    _PREFILL_WRITES = frozenset({"prefill_interpolator", "p_correction_factor"})
    _DECODE_WRITES = frozenset({"decode_interpolator", "d_correction_factor"})

    def __init__(self, prefill_planner, decode_planner, shared_state):
        self.prefill_planner = prefill_planner
        self.decode_planner = decode_planner
        self.shared_state = shared_state
        # Most recent target list built by make_adjustments().
        self.last_target_replicas = []

    async def make_adjustments(self):
        """Run one planning step across both planners.

        Returns early (leaving ``last_target_replicas`` untouched) when the
        shared metrics are not valid or when either planner returns ``None``
        from ``plan_adjustment()``. Otherwise the global GPU budget is
        applied, the targets are recorded, and they are sent to the prefill
        planner's connector unless ``args.no_operation`` is set.
        """
        if not self.shared_state.last_metrics.is_valid():
            return

        prefill = self.prefill_planner
        decode = self.decode_planner

        p_endpoints, d_endpoints = await prefill.get_workers_info()
        self.shared_state.p_endpoints = p_endpoints
        self.shared_state.d_endpoints = d_endpoints

        planned_p = prefill.plan_adjustment()
        planned_d = decode.plan_adjustment()
        if planned_p is None or planned_d is None:
            return

        planned_p, planned_d = _apply_global_gpu_budget(
            planned_p, planned_d, prefill.args
        )
        prefill.update_predicted_replicas_metric(planned_p)
        decode.update_predicted_replicas_metric(planned_d)

        # NOTE(review): the decode component name is read from the prefill
        # planner here, while attribute reads through __getattr__ route
        # decode_component_name to the decode planner -- confirm both
        # planners expose the same value.
        prefill_target = {
            "sub_component_type": "prefill",
            "component_name": prefill.prefill_component_name,
            "desired_replicas": planned_p,
        }
        decode_target = {
            "sub_component_type": "decode",
            "component_name": prefill.decode_component_name,
            "desired_replicas": planned_d,
        }
        targets = [prefill_target, decode_target]
        self.last_target_replicas = targets

        if not prefill.args.no_operation:
            await prefill.connector.set_component_replicas(targets, blocking=False)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup on the harness fails.
        if name == "last_metrics":
            return self.shared_state.last_metrics
        if name == "get_workers_info" or name in self._PREFILL_READS:
            return getattr(self.prefill_planner, name)
        if name in self._DECODE_READS:
            return getattr(self.decode_planner, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name == "last_metrics":
            self.shared_state.last_metrics = value
        elif name in self._SHARED_WRITES:
            # Store locally to support patch.object lifecycle (set/del).
            object.__setattr__(self, name, value)
            setattr(self.prefill_planner, name, value)
            setattr(self.decode_planner, name, value)
        elif name in self._PREFILL_WRITES:
            setattr(self.prefill_planner, name, value)
        elif name in self._DECODE_WRITES:
            setattr(self.decode_planner, name, value)
        else:
            # The harness' own fields and any unrouted name are plain
            # instance attributes.
            super().__setattr__(name, value)
def _replica_count(target_replicas, component_name, default=1):
for replica in target_replicas:
if replica.get("component_name") == component_name:
return replica.get("desired_replicas", default)
return default
@pytest.fixture @pytest.fixture
def planner(): def planner():
"""Set up test environment with mocked dependencies.""" """Set up test environment with mocked dependencies."""
...@@ -75,8 +196,10 @@ def planner(): ...@@ -75,8 +196,10 @@ def planner():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge: with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock() mock_gauge.return_value = Mock()
# Create planner instance shared_state = PlannerSharedState()
planner = Planner(mock_runtime, args) prefill_planner = PrefillPlanner(mock_runtime, args, shared_state=shared_state)
decode_planner = DecodePlanner(mock_runtime, args, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Mock the interpolators to return fixed values for testing # Mock the interpolators to return fixed values for testing
planner.prefill_interpolator = Mock() planner.prefill_interpolator = Mock()
...@@ -165,11 +288,10 @@ class TestReplicaCalculation: ...@@ -165,11 +288,10 @@ class TestReplicaCalculation:
# Extract the calculated values from the log calls or by checking the mock calls # Extract the calculated values from the log calls or by checking the mock calls
# Since we mocked the connector, we can check what replicas were requested # Since we mocked the connector, we can check what replicas were requested
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
prefill_component = "VllmPrefillWorker" prefill_component = "VllmPrefillWorker"
calculated_prefill_replicas = call_args.get(prefill_component, 1) calculated_prefill_replicas = _replica_count(
planner.last_target_replicas, prefill_component
)
print(f"Expected prefill replicas: {expected_prefill_replicas}") print(f"Expected prefill replicas: {expected_prefill_replicas}")
print(f"Calculated prefill replicas: {calculated_prefill_replicas}") print(f"Calculated prefill replicas: {calculated_prefill_replicas}")
...@@ -230,11 +352,10 @@ class TestReplicaCalculation: ...@@ -230,11 +352,10 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Check the results # Check the results
if planner.connector.set_component_replicas.called:
call_args = planner.connector.set_component_replicas.call_args[0][0]
decode_component = "VllmDecodeWorker" decode_component = "VllmDecodeWorker"
calculated_decode_replicas = call_args.get(decode_component, 1) calculated_decode_replicas = _replica_count(
planner.last_target_replicas, decode_component
)
print(f"Expected decode replicas: {expected_decode_replicas}") print(f"Expected decode replicas: {expected_decode_replicas}")
print(f"Calculated decode replicas: {calculated_decode_replicas}") print(f"Calculated decode replicas: {calculated_decode_replicas}")
...@@ -304,12 +425,12 @@ class TestReplicaCalculation: ...@@ -304,12 +425,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify results # Verify results
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}") print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")
assert ( assert (
...@@ -359,12 +480,12 @@ class TestReplicaCalculation: ...@@ -359,12 +480,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify that total GPU usage doesn't exceed budget # Verify that total GPU usage doesn't exceed budget
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
total_gpus = ( total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu prefill_replicas * planner.args.prefill_engine_num_gpu
+ decode_replicas * planner.args.decode_engine_num_gpu + decode_replicas * planner.args.decode_engine_num_gpu
...@@ -417,12 +538,12 @@ class TestReplicaCalculation: ...@@ -417,12 +538,12 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify minimum constraints are respected # Verify minimum constraints are respected
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}") print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")
assert ( assert (
...@@ -482,9 +603,9 @@ class TestReplicaCalculation: ...@@ -482,9 +603,9 @@ class TestReplicaCalculation:
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
# Verify that correction factor was effectively clamped # Verify that correction factor was effectively clamped
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
prefill_replicas = call_args.get("VllmPrefillWorker", 1) )
print( print(
f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}" f"Correction factor clamping test: Expected={expected_prefill_replicas}, Got={prefill_replicas}"
...@@ -501,7 +622,6 @@ class TestReplicaCalculation: ...@@ -501,7 +622,6 @@ class TestReplicaCalculation:
"""Test handling of d_correction_factor <= 0.""" """Test handling of d_correction_factor <= 0."""
# Test both 0 and negative values # Test both 0 and negative values
for correction_factor in [0.0, -1.0]: for correction_factor in [0.0, -1.0]:
with patch.object(planner, "connector") as mock_connector:
planner.p_correction_factor = 1.0 planner.p_correction_factor = 1.0
planner.d_correction_factor = correction_factor planner.d_correction_factor = correction_factor
...@@ -511,9 +631,7 @@ class TestReplicaCalculation: ...@@ -511,9 +631,7 @@ class TestReplicaCalculation:
planner.osl_predictor.predict_next.return_value = 150 planner.osl_predictor.predict_next.return_value = 150
# Mock interpolator outputs # Mock interpolator outputs
planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = ( planner.prefill_interpolator.interpolate_thpt_per_gpu.return_value = 40000
40000
)
planner.decode_interpolator.find_best_throughput_per_gpu.return_value = ( planner.decode_interpolator.find_best_throughput_per_gpu.return_value = (
10000, 10000,
0.01, 0.01,
...@@ -545,9 +663,9 @@ class TestReplicaCalculation: ...@@ -545,9 +663,9 @@ class TestReplicaCalculation:
# Should handle gracefully without crashing # Should handle gracefully without crashing
# The code should use args.itl directly instead of dividing by 0 # The code should use args.itl directly instead of dividing by 0
if mock_connector.set_component_replicas.called: decode_replicas = _replica_count(
call_args = mock_connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmDecodeWorker"
decode_replicas = call_args.get("VllmDecodeWorker", 1) )
print( print(
f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}" f"Correction factor {correction_factor} test: Decode replicas={decode_replicas}"
...@@ -608,12 +726,12 @@ class TestReplicaCalculation: ...@@ -608,12 +726,12 @@ class TestReplicaCalculation:
# Run calculation # Run calculation
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
print( print(
f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})" f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), D={decode_replicas} (expected ~{expected_decode_replicas})"
) )
...@@ -668,12 +786,12 @@ class TestReplicaCalculation: ...@@ -668,12 +786,12 @@ class TestReplicaCalculation:
# Run calculation # Run calculation
asyncio.run(planner.make_adjustments()) asyncio.run(planner.make_adjustments())
if planner.connector.set_component_replicas.called: prefill_replicas = _replica_count(
call_args = planner.connector.set_component_replicas.call_args[0][0] planner.last_target_replicas, "VllmPrefillWorker"
)
prefill_replicas = call_args.get("VllmPrefillWorker", 1) decode_replicas = _replica_count(
decode_replicas = call_args.get("VllmDecodeWorker", 1) planner.last_target_replicas, "VllmDecodeWorker"
)
# Verify total GPU usage doesn't exceed budget # Verify total GPU usage doesn't exceed budget
total_gpus = ( total_gpus = (
prefill_replicas * planner.args.prefill_engine_num_gpu prefill_replicas * planner.args.prefill_engine_num_gpu
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import logging import logging
from dynamo.planner.utils.dryrun import run_sla_planner_dryrun
from dynamo.planner.utils.planner_argparse import create_sla_planner_parser from dynamo.planner.utils.planner_argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,5 +45,4 @@ if __name__ == "__main__": ...@@ -45,5 +45,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
planner = Planner(None, args, dryrun=True) run_sla_planner_dryrun(args)
planner.dryrun_run()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Replace ``planner_core.Gauge`` with a mock for the duration of each test.

    The patched ``Gauge`` returns a fresh ``Mock`` when called, so code that
    instantiates a gauge during a test receives a mock object instead.
    """
    with patch("dynamo.planner.utils.planner_core.Gauge", return_value=Mock()):
        yield
def _build_args():
args = argparse.Namespace()
args.adjustment_interval = 60
args.prefill_engine_num_gpu = 1
args.decode_engine_num_gpu = 1
args.min_endpoint = 1
args.max_gpu_budget = -1
args.ttft = 500.0
args.itl = 50.0
args.backend = "vllm"
args.no_operation = True
args.no_correction = True
args.metric_pulling_prometheus_endpoint = "http://localhost:9090"
args.metric_reporting_prometheus_port = 0
args.load_predictor = "constant"
args.load_predictor_warmup_trace = None
args.profile_results_dir = os.path.join(
os.path.dirname(__file__),
"..",
"profiling_results",
"H200_TP1P_TP1D",
)
args.environment = "kubernetes"
args.namespace = "test-namespace"
args.mode = "disagg"
return args
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(args, prometheus_client):
    """Create a prefill and a decode planner wired together for testing.

    Both planners are constructed with the same ``PlannerSharedState`` and
    then given the supplied Prometheus client, a fixed model name and a
    stubbed ``get_workers_info``.

    Returns:
        A ``(prefill_planner, decode_planner, shared_state)`` tuple.
    """

    async def mock_get_workers_info(require_prefill=True, require_decode=True):
        # One fake endpoint per requested side, none otherwise.
        prefill_workers = ["prefill-0"] if require_prefill else []
        decode_workers = ["decode-0"] if require_decode else []
        return (prefill_workers, decode_workers)

    shared_state = PlannerSharedState()
    prefill_planner = PrefillPlanner(None, args, shared_state=shared_state)
    decode_planner = DecodePlanner(None, args, shared_state=shared_state)
    planners = (prefill_planner, decode_planner)

    for planner in planners:
        planner.prometheus_api_client = prometheus_client
    for planner in planners:
        planner.model_name = "test-model"
    for planner in planners:
        planner.get_workers_info = mock_get_workers_info

    return prefill_planner, decode_planner, shared_state
def _expected_prefill(args, prefill_planner, sample):
pred_prefill_throughput = (
sample["num_req"] * sample["isl"] / args.adjustment_interval
)
thpt_per_gpu = prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
sample["isl"]
)
expected = math.ceil(
pred_prefill_throughput / thpt_per_gpu / args.prefill_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _expected_decode(args, decode_planner, sample):
(
pred_decode_thpt_per_gpu,
_,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=args.itl, context_length=sample["isl"] + sample["osl"] / 2
)
pred_decode_throughput = (
sample["num_req"] * sample["osl"] / args.adjustment_interval
)
expected = math.ceil(
pred_decode_throughput / pred_decode_thpt_per_gpu / args.decode_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
asyncio.run(
prefill_planner.observe_metrics(require_prefill=True, require_decode=True)
)
decode_planner.update_predictors_from_metrics(shared_state.last_metrics)
next_num_p = prefill_planner.plan_adjustment()
next_num_d = decode_planner.plan_adjustment()
return next_num_p, next_num_d
def test_disagg_scale_up():
    """Replica targets grow when the request count jumps from 10 to 5000."""
    args = _build_args()
    # Everything except the request count is the same in both intervals.
    steady = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    light_load = {"num_req": 10, **steady}
    heavy_load = {"num_req": 5000, **steady}

    client = _build_prometheus_client([light_load, heavy_load])
    prefill_planner, decode_planner, shared_state = _build_planners(args, client)

    low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
    high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)

    # Each interval must match the independently computed expectation ...
    assert low_p == _expected_prefill(args, prefill_planner, light_load)
    assert low_d == _expected_decode(args, decode_planner, light_load)
    assert high_p == _expected_prefill(args, prefill_planner, heavy_load)
    assert high_d == _expected_decode(args, decode_planner, heavy_load)
    # ... and the heavier interval must ask for strictly more replicas.
    assert high_p > low_p
    assert high_d > low_d
def test_disagg_scale_down():
    """Replica targets shrink when the request count drops from 5000 to 10."""
    args = _build_args()
    # Everything except the request count is the same in both intervals.
    steady = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    heavy_load = {"num_req": 5000, **steady}
    light_load = {"num_req": 10, **steady}

    client = _build_prometheus_client([heavy_load, light_load])
    prefill_planner, decode_planner, shared_state = _build_planners(args, client)

    high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
    low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)

    # Each interval must match the independently computed expectation ...
    assert high_p == _expected_prefill(args, prefill_planner, heavy_load)
    assert high_d == _expected_decode(args, decode_planner, heavy_load)
    assert low_p == _expected_prefill(args, prefill_planner, light_load)
    assert low_d == _expected_decode(args, decode_planner, light_load)
    # ... and the lighter interval must ask for strictly fewer replicas.
    assert low_p < high_p
    assert low_d < high_d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment