Unverified Commit e2f1e04f authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: separate planner into prefill/decode planner (#5622)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 77aadb72
......@@ -79,6 +79,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
mode = "disagg" # ["disagg", "prefill", "decode"]
class VllmComponentName:
......
......@@ -113,6 +113,8 @@ class KubernetesConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""
Verify that the deployment contains services with subComponentType prefill and decode and the model name exists.
......@@ -126,34 +128,45 @@ class KubernetesConnector(PlannerConnector):
errors = []
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
except PlannerError as e:
errors.append(str(e))
if require_prefill:
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
except PlannerError as e:
errors.append(str(e))
if require_decode:
try:
get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
)
except PlannerError as e:
errors.append(str(e))
try:
get_service_from_sub_component_type_or_name(
self.get_model_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
require_prefill=require_prefill,
require_decode=require_decode,
)
except PlannerError as e:
errors.append(str(e))
try:
self.get_model_name(deployment)
except PlannerError as e:
errors.append(str(e))
# Raise combined error if any issues found
if errors:
raise DeploymentValidationError(errors)
def get_model_name(self, deployment: Optional[dict] = None) -> str:
def get_model_name(
self,
deployment: Optional[dict] = None,
require_prefill: bool = True,
require_decode: bool = True,
) -> str:
"""Get the model name from the deployment"""
try:
if deployment is None:
......@@ -163,16 +176,20 @@ class KubernetesConnector(PlannerConnector):
# TODO: benchmarks/profiler/utils/config.py already contains DGD config parsing
# and model name logic, should consolidate
prefill_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
prefill_model_name = prefill_service.get_model_name()
decode_model_name = decode_service.get_model_name()
prefill_model_name = None
decode_model_name = None
if require_prefill:
prefill_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
)
prefill_model_name = prefill_service.get_model_name()
if require_decode:
decode_service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
)
decode_model_name = decode_service.get_model_name()
if prefill_model_name is None and decode_model_name is None:
raise ModelNameNotFoundError()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Optional
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
_apply_component_gpu_budget,
_apply_global_gpu_budget,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
def run_sla_planner_dryrun(args: argparse.Namespace) -> None:
warmup_metrics = None
if getattr(args, "load_predictor_warmup_trace", None):
warmup_metrics = extract_metrics_from_mooncake(
args.load_predictor_warmup_trace,
args.adjustment_interval,
)
metrics = extract_metrics_from_mooncake(args.dataset, args.adjustment_interval)
if not metrics:
raise ValueError("Empty metrics dataset: cannot run dryrun")
mode = getattr(args, "mode", "disagg")
prefill_planner: Optional[PrefillPlanner] = None
decode_planner: Optional[DecodePlanner] = None
if mode == "disagg":
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(
None, args, dryrun=True, shared_state=shared_state
)
decode_planner = DecodePlanner(
None, args, dryrun=True, shared_state=shared_state
)
elif mode == "prefill":
prefill_planner = PrefillPlanner(None, args, dryrun=True)
elif mode == "decode":
decode_planner = DecodePlanner(None, args, dryrun=True)
else:
raise ValueError(f"Invalid planner mode: {mode}")
def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
"""safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
assert prefill_planner is not None
actual_ttft = prefill_planner.prefill_interpolator.interpolate_ttft(isl)
if actual_ttft > ttft:
return 0
return num_p * prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
isl
)
def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
"""safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
assert decode_planner is not None
(
pred_decode_thpt_per_gpu,
actual_itl,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=itl, context_length=isl + osl / 2
)
if actual_itl > itl:
return 0
return num_d * pred_decode_thpt_per_gpu
time_series = [0]
rr = [metrics[0]["request_count"]]
est_rr = [metrics[0]["request_count"]]
isl = [metrics[0]["avg_isl"]]
est_isl = [metrics[0]["avg_isl"]]
osl = [metrics[0]["avg_osl"]]
est_osl = [metrics[0]["avg_osl"]]
if prefill_planner is not None:
num_p = [args.start_num_p]
p_thpt = [rr[0] * isl[0]]
safe_p_thpt = [
compute_safe_p_thpt(args.start_num_p, isl[0], args.ttft)
* args.adjustment_interval
]
prefill_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])
else:
num_p = [0]
p_thpt = [0]
safe_p_thpt = [0]
if decode_planner is not None:
num_d = [args.start_num_d]
d_thpt = [rr[0] * osl[0]]
safe_d_thpt = [
compute_safe_d_thpt(args.start_num_d, isl[0], osl[0], args.itl)
* args.adjustment_interval
]
decode_planner.dryrun_observe_metrics(rr[0], isl[0], osl[0])
else:
num_d = [0]
d_thpt = [0]
safe_d_thpt = [0]
predictor_planner = prefill_planner or decode_planner
assert predictor_planner is not None
for metric in metrics[1:]:
# update time
time_series.append(time_series[-1] + args.adjustment_interval)
# load prediction
_est_rr, _est_isl, _est_osl = predictor_planner.predict_load()
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
# compute num_p and num_d
_num_p = (
prefill_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if prefill_planner is not None
else 0
)
_num_d = (
decode_planner._compute_replica_requirements(_est_rr, _est_isl, _est_osl)
if decode_planner is not None
else 0
)
# apply GPU budget
if prefill_planner is not None and decode_planner is not None:
_num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, args)
elif prefill_planner is not None:
_num_p = _apply_component_gpu_budget(
_num_p, args.prefill_engine_num_gpu, args
)
elif decode_planner is not None:
_num_d = _apply_component_gpu_budget(
_num_d, args.decode_engine_num_gpu, args
)
num_p.append(_num_p)
num_d.append(_num_d)
# update load predictor
for planner in [prefill_planner, decode_planner]:
if planner is not None:
planner.dryrun_observe_metrics(
metric["request_count"], metric["avg_isl"], metric["avg_osl"]
)
# fill in ground truth
rr.append(metric["request_count"])
isl.append(metric["avg_isl"])
osl.append(metric["avg_osl"])
p_thpt.append(rr[-1] * isl[-1] if prefill_planner is not None else 0)
d_thpt.append(rr[-1] * osl[-1] if decode_planner is not None else 0)
safe_p_thpt.append(
compute_safe_p_thpt(num_p[-1], isl[-1], args.ttft)
* args.adjustment_interval
if prefill_planner is not None
else 0
)
safe_d_thpt.append(
compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], args.itl)
* args.adjustment_interval
if decode_planner is not None
else 0
)
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
interval = args.adjustment_interval
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time_series,
rr=rr,
est_rr=est_rr,
isl=isl,
est_isl=est_isl,
osl=osl,
est_osl=est_osl,
num_p=num_p,
p_thpt=p_thpt,
safe_p_thpt=safe_p_thpt,
num_d=num_d,
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=args.output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
)
......@@ -42,6 +42,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
choices=["vllm", "sglang", "trtllm", "mocker"],
help="Backend type",
)
parser.add_argument(
"--mode",
default=SLAPlannerDefaults.mode,
choices=["disagg", "prefill", "decode"],
help="Planner mode: disagg (prefill+decode), prefill-only, or decode-only",
)
parser.add_argument(
"--no-operation",
action="store_true",
......@@ -61,7 +67,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
"--max-gpu-budget",
type=int,
default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum GPU budget",
help="Maximum GPU budget (-1 for no budget enforcement)",
)
parser.add_argument(
"--min-endpoint",
......
......@@ -130,6 +130,8 @@ class VirtualConnector(PlannerConnector):
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
require_prefill: bool = True,
require_decode: bool = True,
):
"""Validate the deployment"""
pass
......@@ -138,6 +140,9 @@ class VirtualConnector(PlannerConnector):
"""Wait for the deployment to be ready"""
await self._wait_for_scaling_completion()
async def get_model_name(self) -> str:
async def get_model_name(
self, require_prefill: bool = True, require_decode: bool = True
) -> str:
"""Get the model name from the deployment"""
del require_prefill, require_decode
return self.model_name
This diff is collapsed.
......@@ -15,8 +15,8 @@
import logging
from dynamo.planner.utils.dryrun import run_sla_planner_dryrun
from dynamo.planner.utils.planner_argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__)
......@@ -45,5 +45,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
planner = Planner(None, args, dryrun=True)
planner.dryrun_run()
run_sla_planner_dryrun(args)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import math
import os
from unittest.mock import Mock, patch
import pytest
from dynamo.planner.utils.planner_core import (
DecodePlanner,
PlannerSharedState,
PrefillPlanner,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
with patch("dynamo.planner.utils.planner_core.Gauge") as mock_gauge:
mock_gauge.return_value = Mock()
yield
def _build_args():
args = argparse.Namespace()
args.adjustment_interval = 60
args.prefill_engine_num_gpu = 1
args.decode_engine_num_gpu = 1
args.min_endpoint = 1
args.max_gpu_budget = -1
args.ttft = 500.0
args.itl = 50.0
args.backend = "vllm"
args.no_operation = True
args.no_correction = True
args.metric_pulling_prometheus_endpoint = "http://localhost:9090"
args.metric_reporting_prometheus_port = 0
args.load_predictor = "constant"
args.load_predictor_warmup_trace = None
args.profile_results_dir = os.path.join(
os.path.dirname(__file__),
"..",
"profiling_results",
"H200_TP1P_TP1D",
)
args.environment = "kubernetes"
args.namespace = "test-namespace"
args.mode = "disagg"
return args
def _build_prometheus_client(samples):
client = Mock()
client.get_avg_time_to_first_token.side_effect = [
s["ttft_ms"] / 1000 for s in samples
]
client.get_avg_inter_token_latency.side_effect = [
s["itl_ms"] / 1000 for s in samples
]
client.get_avg_request_count.side_effect = [s["num_req"] for s in samples]
client.get_avg_request_duration.side_effect = [
s["request_duration"] for s in samples
]
client.get_avg_input_sequence_tokens.side_effect = [s["isl"] for s in samples]
client.get_avg_output_sequence_tokens.side_effect = [s["osl"] for s in samples]
return client
def _build_planners(args, prometheus_client):
shared_state = PlannerSharedState()
prefill_planner = PrefillPlanner(None, args, shared_state=shared_state)
decode_planner = DecodePlanner(None, args, shared_state=shared_state)
prefill_planner.prometheus_api_client = prometheus_client
decode_planner.prometheus_api_client = prometheus_client
prefill_planner.model_name = "test-model"
decode_planner.model_name = "test-model"
async def mock_get_workers_info(require_prefill=True, require_decode=True):
return (
["prefill-0"] if require_prefill else [],
["decode-0"] if require_decode else [],
)
prefill_planner.get_workers_info = mock_get_workers_info
decode_planner.get_workers_info = mock_get_workers_info
return prefill_planner, decode_planner, shared_state
def _expected_prefill(args, prefill_planner, sample):
pred_prefill_throughput = (
sample["num_req"] * sample["isl"] / args.adjustment_interval
)
thpt_per_gpu = prefill_planner.prefill_interpolator.interpolate_thpt_per_gpu(
sample["isl"]
)
expected = math.ceil(
pred_prefill_throughput / thpt_per_gpu / args.prefill_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _expected_decode(args, decode_planner, sample):
(
pred_decode_thpt_per_gpu,
_,
_,
) = decode_planner.decode_interpolator.find_best_throughput_per_gpu(
itl=args.itl, context_length=sample["isl"] + sample["osl"] / 2
)
pred_decode_throughput = (
sample["num_req"] * sample["osl"] / args.adjustment_interval
)
expected = math.ceil(
pred_decode_throughput / pred_decode_thpt_per_gpu / args.decode_engine_num_gpu
)
return max(expected, args.min_endpoint)
def _run_interval(prefill_planner, decode_planner, shared_state):
asyncio.run(
prefill_planner.observe_metrics(require_prefill=True, require_decode=True)
)
decode_planner.update_predictors_from_metrics(shared_state.last_metrics)
next_num_p = prefill_planner.plan_adjustment()
next_num_d = decode_planner.plan_adjustment()
return next_num_p, next_num_d
def test_disagg_scale_up():
args = _build_args()
samples = [
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(args, client)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert low_p == _expected_prefill(args, prefill_planner, samples[0])
assert low_d == _expected_decode(args, decode_planner, samples[0])
assert high_p == _expected_prefill(args, prefill_planner, samples[1])
assert high_d == _expected_decode(args, decode_planner, samples[1])
assert high_p > low_p
assert high_d > low_d
def test_disagg_scale_down():
args = _build_args()
samples = [
{
"num_req": 5000,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
{
"num_req": 10,
"isl": 3000,
"osl": 150,
"ttft_ms": 400.0,
"itl_ms": 30.0,
"request_duration": 20.0,
},
]
client = _build_prometheus_client(samples)
prefill_planner, decode_planner, shared_state = _build_planners(args, client)
high_p, high_d = _run_interval(prefill_planner, decode_planner, shared_state)
low_p, low_d = _run_interval(prefill_planner, decode_planner, shared_state)
assert high_p == _expected_prefill(args, prefill_planner, samples[0])
assert high_d == _expected_decode(args, decode_planner, samples[0])
assert low_p == _expected_prefill(args, prefill_planner, samples[1])
assert low_d == _expected_decode(args, decode_planner, samples[1])
assert low_p < high_p
assert low_d < high_d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment