"docs/vscode:/vscode.git/clone" did not exist on "02c822d6b919ececa248755e8fc64c96123a1257"
Unverified Commit 1b3a1073 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

test: add dryrun mode for sla planner (#2557)


Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 626d7e18
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
def _plot_load_and_workers(
    ax,
    time: list,
    thpt: list,
    safe_thpt: list,
    num_workers: list,
    stage: str,
    thpt_fmt: str,
    safe_thpt_fmt: str,
    workers_fmt: str,
) -> None:
    """Draw one throughput-vs-workers panel (used for both prefill and decode).

    Throughput curves go on the left y-axis of ``ax``; the worker count goes on
    a twin right y-axis so the two different scales can share one panel.

    Args:
        ax: Matplotlib axes to draw the throughput curves on
        time: List of time points
        thpt: List of actual throughputs
        safe_thpt: List of safe throughput limits
        num_workers: List of worker counts
        stage: Name used in labels and the title ("Prefill" or "Decode")
        thpt_fmt: Matplotlib format string for the actual throughput curve
        safe_thpt_fmt: Matplotlib format string for the safe limit curve
        workers_fmt: Matplotlib format string for the worker count curve
    """
    ax.plot(time, thpt, thpt_fmt, label=f"Actual {stage} Throughput", linewidth=2)
    ax.plot(
        time,
        safe_thpt,
        safe_thpt_fmt,
        label=f"Safe {stage} Throughput Limit",
        linewidth=2,
    )
    ax_right = ax.twinx()
    ax_right.plot(
        time,
        num_workers,
        workers_fmt,
        label=f"{stage} Workers",
        linewidth=2,
        marker="o",
    )
    # Each axes only knows its own artists, so merge both handle lists to get
    # a single legend covering the left- and right-axis curves.
    lines_left, labels_left = ax.get_legend_handles_labels()
    lines_right, labels_right = ax_right.get_legend_handles_labels()
    ax.legend(lines_left + lines_right, labels_left + labels_right, loc="upper left")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Throughput (tok/adjustment_interval)")
    ax.set_ylim(bottom=0)
    ax_right.set_ylabel("Number of Workers")
    ax_right.set_ylim(bottom=0)
    ax.set_title(f"{stage} Load and Workers")
    ax.grid(True, alpha=0.3)


def create_dryrun_plot(
    time: list,
    rr: list,
    est_rr: list,
    isl: list,
    est_isl: list,
    osl: list,
    est_osl: list,
    num_p: list,
    p_thpt: list,
    safe_p_thpt: list,
    num_d: list,
    d_thpt: list,
    safe_d_thpt: list,
    output_path: str,
) -> None:
    """
    Create a comprehensive dryrun plot with 4 subplots showing various metrics over time.

    The 2x2 figure contains: request rate (actual vs. predicted), input/output
    sequence lengths (actual vs. predicted), prefill throughput with prefill
    worker count, and decode throughput with decode worker count. The figure
    is written to ``output_path`` and then closed.

    Args:
        time: List of time points
        rr: List of actual request rates
        est_rr: List of estimated request rates
        isl: List of actual input sequence lengths
        est_isl: List of estimated input sequence lengths
        osl: List of actual output sequence lengths
        est_osl: List of estimated output sequence lengths
        num_p: List of prefill worker counts
        p_thpt: List of actual prefill throughputs
        safe_p_thpt: List of safe prefill throughput limits
        num_d: List of decode worker counts
        d_thpt: List of actual decode throughputs
        safe_d_thpt: List of safe decode throughput limits
        output_path: Path where the plot should be saved
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    try:
        # Plot 1: Request Rate
        ax1.plot(time, rr, "b-", label="Actual Request Rate", linewidth=2)
        ax1.plot(time, est_rr, "r--", label="Predicted Request Rate", linewidth=2)
        ax1.set_xlabel("Time (s)")
        ax1.set_ylabel("Request Rate")
        ax1.set_ylim(bottom=0)
        ax1.set_title("Request Rate Over Time")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: Sequence Lengths
        ax2.plot(time, isl, "g-", label="Actual ISL", linewidth=2)
        ax2.plot(time, est_isl, "g--", label="Predicted ISL", linewidth=2)
        ax2.plot(time, osl, "m-", label="Actual OSL", linewidth=2)
        ax2.plot(time, est_osl, "m--", label="Predicted OSL", linewidth=2)
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Num Tokens")
        ax2.set_ylim(bottom=0)
        ax2.set_title("Input/Output Sequence Lengths Over Time")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Plot 3: Prefill throughput (left axis) and prefill workers (right axis)
        _plot_load_and_workers(
            ax3, time, p_thpt, safe_p_thpt, num_p, "Prefill", "b-", "b--", "c-"
        )

        # Plot 4: Decode throughput (left axis) and decode workers (right axis)
        _plot_load_and_workers(
            ax4, time, d_thpt, safe_d_thpt, num_d, "Decode", "r-", "r--", "orange"
        )

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Close this specific figure even if plotting or saving fails, so the
        # figure is not left registered with pyplot.
        plt.close(fig)
...@@ -89,6 +89,12 @@ class ARIMAPredictor(BasePredictor): ...@@ -89,6 +89,12 @@ class ARIMAPredictor(BasePredictor):
if len(self.data_buffer) < self.minimum_data_points: if len(self.data_buffer) < self.minimum_data_points:
return self.get_last_value() return self.get_last_value()
# Check if all values are the same (constant data)
# pmdarima will predict 0 for constant data, we need to correct its prediction
if len(set(self.data_buffer)) == 1:
return self.data_buffer[0] # Return the constant value
try:
# Fit auto ARIMA model # Fit auto ARIMA model
self.model = pmdarima.auto_arima( self.model = pmdarima.auto_arima(
self.data_buffer, self.data_buffer,
...@@ -99,6 +105,10 @@ class ARIMAPredictor(BasePredictor): ...@@ -99,6 +105,10 @@ class ARIMAPredictor(BasePredictor):
# Make prediction # Make prediction
forecast = self.model.predict(n_periods=1) forecast = self.model.predict(n_periods=1)
return forecast[0] return forecast[0]
except Exception as e:
# Log the specific error for debugging
logger.warning(f"ARIMA prediction failed: {e}, using last value")
return self.get_last_value()
# Time-series forecasting model from Meta # Time-series forecasting model from Meta
......
...@@ -19,6 +19,7 @@ from dynamo.planner.utils.perf_interpolation import ( ...@@ -19,6 +19,7 @@ from dynamo.planner.utils.perf_interpolation import (
PrefillInterpolator, PrefillInterpolator,
) )
from dynamo.planner.utils.prometheus import PrometheusAPIClient from dynamo.planner.utils.prometheus import PrometheusAPIClient
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -52,9 +53,17 @@ class Metrics: ...@@ -52,9 +53,17 @@ class Metrics:
class Planner: class Planner:
def __init__(self, runtime: DistributedRuntime, args: argparse.Namespace): def __init__(
self.runtime = runtime self,
runtime: Optional[DistributedRuntime],
args: argparse.Namespace,
dryrun: bool = False,
):
self.args = args self.args = args
self.dryrun = dryrun
if not self.dryrun:
self.runtime = runtime
self.namespace = SLAPlannerDefaults.namespace self.namespace = SLAPlannerDefaults.namespace
if not args.no_operation: if not args.no_operation:
...@@ -80,6 +89,7 @@ class Planner: ...@@ -80,6 +89,7 @@ class Planner:
self.prefill_interpolator = PrefillInterpolator(args.profile_results_dir) self.prefill_interpolator = PrefillInterpolator(args.profile_results_dir)
self.decode_interpolator = DecodeInterpolator(args.profile_results_dir) self.decode_interpolator = DecodeInterpolator(args.profile_results_dir)
if not self.dryrun:
self.prefill_client = None self.prefill_client = None
self.workers_client = None self.workers_client = None
self.p_endpoints = [] # type: ignore self.p_endpoints = [] # type: ignore
...@@ -88,16 +98,16 @@ class Planner: ...@@ -88,16 +98,16 @@ class Planner:
self.last_adjustment_time = time.time() self.last_adjustment_time = time.time()
self.last_metrics = Metrics() self.last_metrics = Metrics()
self.p_correction_factor = 1.0
self.d_correction_factor = 1.0
self.no_correction = args.no_correction
self.prometheus_port = args.prometheus_port self.prometheus_port = args.prometheus_port
# Initialize Prometheus metrics # Initialize Prometheus metrics
# TODO: use proper naming # TODO: use proper naming
self.num_p_workers_gauge = Gauge("num_p_workers", "Number of prefill workers") self.num_p_workers_gauge = Gauge(
self.num_d_workers_gauge = Gauge("num_d_workers", "Number of decode workers") "num_p_workers", "Number of prefill workers"
)
self.num_d_workers_gauge = Gauge(
"num_d_workers", "Number of decode workers"
)
# Start Prometheus HTTP server if port is specified # Start Prometheus HTTP server if port is specified
if self.prometheus_port != 0: if self.prometheus_port != 0:
...@@ -109,7 +119,17 @@ class Planner: ...@@ -109,7 +119,17 @@ class Planner:
except Exception as e: except Exception as e:
logger.error(f"Failed to start Prometheus metrics server: {e}") logger.error(f"Failed to start Prometheus metrics server: {e}")
self.p_correction_factor = 1.0
self.d_correction_factor = 1.0
if self.dryrun:
self.no_correction = True
else:
self.no_correction = args.no_correction
async def get_workers_info(self): async def get_workers_info(self):
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
try: try:
if self.prefill_client is None: if self.prefill_client is None:
self.prefill_client = ( self.prefill_client = (
...@@ -204,43 +224,7 @@ class Planner: ...@@ -204,43 +224,7 @@ class Planner:
self.isl_predictor.add_data_point(self.last_metrics.isl) self.isl_predictor.add_data_point(self.last_metrics.isl)
self.osl_predictor.add_data_point(self.last_metrics.osl) self.osl_predictor.add_data_point(self.last_metrics.osl)
async def make_adjustments(self): def predict_load(self):
if not self.no_correction:
try:
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
# first correct the prediction correction factor
# for TTFT, we expect the correction factor to be << 1 due to queuing delay
expect_ttft = self.prefill_interpolator.interpolate_ttft(
self.last_metrics.isl
)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for ITL, we expect the correction factor to be close to 1
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
try: try:
# predict the next load # predict the next load
next_num_req = self.num_req_predictor.predict_next() next_num_req = self.num_req_predictor.predict_next()
...@@ -249,11 +233,29 @@ class Planner: ...@@ -249,11 +233,29 @@ class Planner:
logger.info( logger.info(
f"Predicted load: num_req={next_num_req:.2f}, isl={next_isl:.2f}, osl={next_osl:.2f}" f"Predicted load: num_req={next_num_req:.2f}, isl={next_isl:.2f}, osl={next_osl:.2f}"
) )
return next_num_req, next_isl, next_osl
except Exception as e: except Exception as e:
logger.error(f"Failed to predict load: {e}") logger.error(f"Failed to predict load: {e}")
return return None, None, None
try: def dryrun_observe_metrics(self, num_req: int, isl_avg: float, osl_avg: float):
self.num_req_predictor.add_data_point(num_req)
self.isl_predictor.add_data_point(isl_avg)
self.osl_predictor.add_data_point(osl_avg)
def _compute_replica_requirements(
self, next_num_req: float, next_isl: float, next_osl: float
) -> tuple[int, int]:
"""Compute the number of prefill and decode replicas needed based on predicted load.
Args:
next_num_req: Predicted number of requests
next_isl: Predicted input sequence length
next_osl: Predicted output sequence length
Returns:
tuple[int, int]: Number of prefill and decode replicas needed
"""
# compute how many replicas are needed for prefill # compute how many replicas are needed for prefill
# here we assume the prefill bias is purely due to request queueing # here we assume the prefill bias is purely due to request queueing
# and we increase the number of prefill replicas linearly to account for the queueing delay # and we increase the number of prefill replicas linearly to account for the queueing delay
...@@ -323,6 +325,53 @@ class Planner: ...@@ -323,6 +325,53 @@ class Planner:
logger.warning( logger.warning(
f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({self.args.max_gpu_budget}), scaling down to {next_num_p} prefill and {next_num_d} decode replicas" f"Total number of GPUs required ({total_gpu_required}) exceeds the max GPU budget ({self.args.max_gpu_budget}), scaling down to {next_num_p} prefill and {next_num_d} decode replicas"
) )
return next_num_p, next_num_d
async def make_adjustments(self):
# Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
if not self.no_correction:
try:
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info(
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
)
# first correct the prediction correction factor
# for TTFT, we expect the correction factor to be << 1 due to queuing delay
expect_ttft = self.prefill_interpolator.interpolate_ttft(
self.last_metrics.isl
)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for ITL, we expect the correction factor to be close to 1
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.d_endpoints)
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
next_num_req, next_isl, next_osl = self.predict_load()
if next_num_req is not None and next_isl is not None and next_osl is not None:
try:
next_num_p, next_num_d = self._compute_replica_requirements(
next_num_req, next_isl, next_osl
)
except Exception as e: except Exception as e:
logger.error(f"Failed to compute number of replicas: {e}") logger.error(f"Failed to compute number of replicas: {e}")
return return
...@@ -358,6 +407,123 @@ class Planner: ...@@ -358,6 +407,123 @@ class Planner:
# sleep for a while to avoid busy-waiting but not too long to miss the next adjustment # sleep for a while to avoid busy-waiting but not too long to miss the next adjustment
await asyncio.sleep(self.args.adjustment_interval / 10) await asyncio.sleep(self.args.adjustment_interval / 10)
def dryrun_run(self):
    """Run planner in dry-run mode with dataset.

    Replays the per-interval metrics extracted from ``self.args.dataset``:
    for every interval after the first, the load predictors forecast the
    next load, the replica counts are computed from that forecast, and only
    then is the interval's ground truth fed back to the predictors. The
    resulting series are written as a plot to ``self.args.output_plot``.

    The first interval has no forecast, so its "estimated" values equal the
    observed ones and its worker counts are ``self.args.start_num_p`` /
    ``self.args.start_num_d``.

    Raises:
        ValueError: If the dataset yields no intervals.
    """
    interval = self.args.adjustment_interval
    metrics = extract_metrics_from_mooncake(self.args.dataset, interval)
    if not metrics:
        raise ValueError(
            f"No requests found in dataset {self.args.dataset}, cannot dry run"
        )

    def compute_safe_p_thpt(num_p: int, isl: float, ttft: float):
        """safe throughput is maximum throughput that the engine can handle given the TTFT SLA"""
        actual_ttft = self.prefill_interpolator.interpolate_ttft(isl)
        if actual_ttft > ttft:
            # the interpolated TTFT alone already exceeds the target
            return 0
        return num_p * self.prefill_interpolator.interpolate_thpt_per_gpu(isl)

    def compute_safe_d_thpt(num_d: int, isl: float, osl: float, itl: float):
        """safe throughput is maximum throughput that the engine can handle given the ITL SLA"""
        (
            pred_decode_thpt_per_gpu,
            actual_itl,
            _,
        ) = self.decode_interpolator.find_best_throughput_per_gpu(
            itl=itl, context_length=isl + osl / 2
        )
        if actual_itl > itl:
            # even the best operating point found exceeds the ITL target
            return 0
        return num_d * pred_decode_thpt_per_gpu

    # Seed every series with the first interval.
    first = metrics[0]
    time_points = [0]
    rr = [first["request_count"]]
    est_rr = [first["request_count"]]
    isl = [first["avg_isl"]]
    est_isl = [first["avg_isl"]]
    osl = [first["avg_osl"]]
    est_osl = [first["avg_osl"]]
    num_p = [self.args.start_num_p]
    num_d = [self.args.start_num_d]
    p_thpt = [rr[0] * isl[0]]
    d_thpt = [rr[0] * osl[0]]
    # safe throughput is scaled by the interval length so it is in the same
    # unit (per adjustment interval) as the actual throughput series
    safe_p_thpt = [compute_safe_p_thpt(num_p[0], isl[0], self.args.ttft) * interval]
    safe_d_thpt = [
        compute_safe_d_thpt(num_d[0], isl[0], osl[0], self.args.itl) * interval
    ]

    self.dryrun_observe_metrics(rr[0], isl[0], osl[0])

    for metric in metrics[1:]:
        # update time
        time_points.append(time_points[-1] + interval)

        # load prediction
        _est_rr, _est_isl, _est_osl = self.predict_load()
        if _est_rr is None or _est_isl is None or _est_osl is None:
            # predict_load() returns (None, None, None) after logging a
            # failure; fall back to the last observed load so the replica
            # computation and the plot still get numeric values.
            _est_rr, _est_isl, _est_osl = rr[-1], isl[-1], osl[-1]
        est_rr.append(_est_rr)
        est_isl.append(_est_isl)
        est_osl.append(_est_osl)

        # compute num_p and num_d
        _num_p, _num_d = self._compute_replica_requirements(
            _est_rr, _est_isl, _est_osl
        )
        num_p.append(_num_p)
        num_d.append(_num_d)

        # update load predictor (after predicting, so the forecast never
        # sees the interval it is being compared against)
        self.dryrun_observe_metrics(
            metric["request_count"], metric["avg_isl"], metric["avg_osl"]
        )

        # fill in ground truth
        rr.append(metric["request_count"])
        isl.append(metric["avg_isl"])
        osl.append(metric["avg_osl"])
        p_thpt.append(rr[-1] * isl[-1])
        d_thpt.append(rr[-1] * osl[-1])
        safe_p_thpt.append(
            compute_safe_p_thpt(num_p[-1], isl[-1], self.args.ttft) * interval
        )
        safe_d_thpt.append(
            compute_safe_d_thpt(num_d[-1], isl[-1], osl[-1], self.args.itl)
            * interval
        )

    # plot the results
    # (function-scope import: only the dry-run path uses the plotting helper)
    from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot

    create_dryrun_plot(
        time=time_points,
        rr=rr,
        est_rr=est_rr,
        isl=isl,
        est_isl=est_isl,
        osl=osl,
        est_osl=est_osl,
        num_p=num_p,
        p_thpt=p_thpt,
        safe_p_thpt=safe_p_thpt,
        num_d=num_d,
        d_thpt=d_thpt,
        safe_d_thpt=safe_d_thpt,
        output_path=self.args.output_plot,
    )
async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace): async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args) planner = Planner(runtime, args)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
from collections import defaultdict
from typing import Any, Dict, List
def extract_metrics_from_mooncake(
    dataset: str, adjustment_interval: int
) -> List[Dict[str, Any]]:
    """
    Extract metrics from mooncake-style JSONL data.

    Each non-blank line must be a JSON object with a ``timestamp`` (in
    milliseconds), an ``input_length`` and an ``output_length``. Records are
    bucketed by ``adjustment_interval`` and aggregated while the file is
    streamed, so the whole trace is never held in memory.

    Args:
        dataset: Path to the JSONL file containing mooncake trace data
        adjustment_interval: Time interval in seconds to group requests

    Returns:
        List of dictionaries containing metrics for each interval, sorted by
        interval start. Intervals that contain no requests are not emitted.
        - interval_start: Start time of the interval (in seconds)
        - request_count: Total number of requests in the interval
        - avg_isl: Average input sequence length
        - avg_osl: Average output sequence length

    Raises:
        ValueError: If ``adjustment_interval`` is not positive.
    """
    if adjustment_interval <= 0:
        raise ValueError(
            f"adjustment_interval must be positive, got {adjustment_interval}"
        )

    # interval_start -> [request_count, total_isl, total_osl]
    totals = defaultdict(lambda: [0, 0, 0])
    with open(dataset, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            # Convert milliseconds to seconds and find the interval
            timestamp_sec = record["timestamp"] / 1000
            interval_start = (
                int(timestamp_sec // adjustment_interval) * adjustment_interval
            )
            bucket = totals[interval_start]
            bucket[0] += 1
            bucket[1] += record["input_length"]
            bucket[2] += record["output_length"]

    # A bucket only exists once a record has landed in it, so request_count
    # is always >= 1 here and the averages cannot divide by zero.
    return [
        {
            "interval_start": interval_start,
            "request_count": request_count,
            "avg_isl": total_isl / request_count,
            "avg_osl": total_osl / request_count,
        }
        for interval_start, (request_count, total_isl, total_osl) in sorted(
            totals.items()
        )
    ]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dynamo.planner.utils.argparse import create_sla_planner_parser
from dynamo.planner.utils.planner_core import Planner
logger = logging.getLogger(__name__)
if __name__ == "__main__":
    # Options that only exist for the dry run, registered on top of the
    # standard SLA planner parser in this order.
    dryrun_options = (
        (
            "--dataset",
            {"type": str, "required": True, "help": "Path to the jsonl dataset file"},
        ),
        (
            "--start-num-p",
            {
                "type": int,
                "default": 1,
                "help": "Number of prefill workers to start with",
            },
        ),
        (
            "--start-num-d",
            {
                "type": int,
                "default": 1,
                "help": "Number of decode workers to start with",
            },
        ),
        (
            "--output-plot",
            {
                "type": str,
                "default": "dryrun_plot.png",
                "help": "Path to the output plot file",
            },
        ),
    )

    parser = create_sla_planner_parser()
    for flag, options in dryrun_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Dry run passes None in place of a distributed runtime.
    planner = Planner(None, args, dryrun=True)
    planner.dryrun_run()
...@@ -44,20 +44,87 @@ For example, to test the interpolator for `nvidia/Llama-3.1-8B-Instruct-FP8` on ...@@ -44,20 +44,87 @@ For example, to test the interpolator for `nvidia/Llama-3.1-8B-Instruct-FP8` on
python components/planner/src/dynamo/planner/utils/perf_interpolation.py \ python components/planner/src/dynamo/planner/utils/perf_interpolation.py \
--profile_results_dir tests/planner/profiling_results/H200_TP1P_TP1D/ \ --profile_results_dir tests/planner/profiling_results/H200_TP1P_TP1D/ \
--isl 3000 \ --isl 3000 \
--osl 150 \ --osl 300 \
--ttft 0.1 \ --ttft 0.1 \
--itl 0.01 --itl 0.01
> ISL=3000, OSL=150 > ISL=3000, OSL=300
> TTFT=0.1s, ITL=0.01s > TTFT=0.1s, ITL=0.01s
> Using profile results from tests/planner/profiling_results/H200_TP1P_TP1D/ > Using profile results from tests/planner/profiling_results/H200_TP1P_TP1D/
> >
> Interpolating prefill performance ... > Interpolating prefill performance ...
> Estimated TTFT=0.027s <= target TTFT=0.100s. Requests can queue 0.073s maximally while meeting TTFT SLA. > Estimated TTFT=0.027s <= target TTFT=0.100s. Requests can queue 0.073s maximally while meeting TTFT SLA.
> Estimated throughput: 110893.48 tokens/s/gpu. Request rate at 36.96 requests/s will saturate one GPU. > Estimated throughput: 110893.48 tokens/s/gpu. Request rate at 36.96 requests/s will saturate one GPU.
>
> Interpolating decode performance ... Interpolating decode performance ...
> Average context length: isl + osl/2 = 3075. > Average context length: isl + osl/2 = 3150.
> Estimated ITL=0.0098s <= target ITL=0.0100s at 33.33% active kv usage. > Estimated ITL=0.0098s <= target ITL=0.0100s at 36.36% active kv usage.
> Estimated throughput: 10226.60 token/s/gpu. Request rate at 68.18 requests/s will saturate one GPU. > Estimated throughput: 10009.88 token/s/gpu. Request rate at 33.37 requests/s will saturate one GPU.
``` ```
## Generating Load Dataset
We provide a tool to generate a load dataset with a varying request rate. More details can be found in [sin_load_generator](../../benchmarks/sin_load_generator/README.md).
From the previous interpolator test, one GPU can handle ~30 requests/s at ISL 3000 and OSL 300 for both prefill and decode.
To test the planner's performance at different request rates, we can generate a load dataset with a request rate varying between 20 and 80 requests/s.
For a TP1 H200 engine, the planner should scale between 1P1D and 3P3D.
```bash
python benchmarks/sin_load_generator/sin_synth.py \
--time-duration 1800 \
--request-rate-min 20 \
--request-rate-max 80 \
--request-rate-period 600 \
--isl1 3000 \
--osl1 300 \
--isl2 3000 \
--osl2 300 \
--output-file rr-20-80_i3000o300.jsonl
```
The dataset starts at 20 requests/s, increases to 80 requests/s at t=300s, decreases back to 20 requests/s at t=600s, and repeats.
The total duration is 30 minutes or 1800 seconds.
## Planner Dry Run
Before testing the SLA planner on real deployments, we provide a dry run feature to test the autoscaling behavior on a given dataset. Specifically, in dry run mode:
- The load predictor will be tested. However, the load metrics will differ from those of a real deployment because the actual OSL is only known after the requests are processed.
- There will be no SLA predictions. Instead, the SLA planner will show the safe throughput limit that ensures the requests can be processed within the SLA.
- The correction factor will be disabled because there are no SLA metrics to use as a reference.
To dry run SLA planner,
```bash
python components/planner/test/planner_sla_dryrun.py \
--<SLA planner arguments> \
--dry-run \
--dataset <path_to_jsonl_dataset> \
--start-num-p <num_prefill_workers_to_start_with> \
--start-num-d <num_decode_workers_to_start_with> \
--output-plot <path_to_output_plot>
```
For example, to dry run SLA planner for the previous FP8 8B on H200 using the generated `rr-20-80_i3000o300.jsonl` dataset,
```bash
python components/planner/test/planner_sla_dryrun.py \
--ttft 0.1 \
--itl 0.01 \
--adjustment-interval 60 \
--profile-results-dir tests/planner/profiling_results/H200_TP1P_TP1D/ \
--dataset rr-20-80_i3000o300.jsonl \
--start-num-p 1 \
--start-num-d 1 \
--output-plot dryrun_plot.png
```
Below is the dryrun result:
![Dryrun Plot](./figures/dryrun_plot.png)
The first plot shows the actual request rate and the predicted request rate (in the unit of requests/adjustment_interval).
The second plot shows the actual ISL/OSL and the predicted ISL/OSL. The first two plots are useful when tuning the performance of the load predictor.
The third plot shows the actual prefill throughput, the number of prefill workers that the planner scales to, and the safe throughput limit for that number of prefill workers. If the actual throughput is below the safe throughput limit, the deployment has the capacity to adhere to the TTFT SLA. Note that in a real deployment, due to other factors such as queueing, load balancing, KV cache transfer latency, and ISL variance, it is not guaranteed that the deployment will adhere to the TTFT SLA.
The fourth plot, similar to the third plot, shows the actual decode throughput, the number of decode workers that the planner scales to, and the safe throughput limit for that number of decode workers. If the actual throughput is below the safe throughput limit, the deployment has the capacity to adhere to the ITL SLA. Note that in a real deployment, due to other factors such as load balancing and OSL variance, it is not guaranteed that the deployment will adhere to the ITL SLA.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment