Unverified Commit ef96fa3f authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Benchmark][2/2] Use spline interpolation to tune SLA variables (#32095)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2a4dbe24
......@@ -129,10 +129,10 @@ vllm bench sweep serve_sla \
The algorithm for adjusting the SLA variable is as follows:
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
- For example, the initial request rate is set to the concurrency under infinite QPS.
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
1. Run the benchmark once with the maximum possible QPS and once with the minimum possible QPS. For each run, calculate the distance of the SLA metrics from their targets, yielding data points of QPS vs. SLA distance.
2. Perform spline interpolation between the data points to estimate the QPS that results in zero SLA distance.
3. Run the benchmark with the estimated QPS and add the resulting data point to the history.
4. Repeat Steps 2 and 3 until the minimum QPS that fails the SLA in the history is no more than 1 above the maximum QPS that passes the SLA.
!!! important
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
......
......@@ -5,7 +5,7 @@ from pathlib import Path
from unittest.mock import patch
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value
from vllm.benchmarks.sweep.serve_sla import solve_sla
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
SLACriterionBase,
......@@ -39,18 +39,70 @@ def _set_return_value(
return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
def _var2metric_identity(bench_comb):
return [{"request_throughput": float(bench_comb["request_rate"])}]
def _var2metric_linear():
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = x
return [{"request_throughput": y}]
def _run_estimate_sla_bounds(
return wrapped
def _var2metric_concave(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 0.5 * (x - elbow_point) + elbow_point
else:
y = 1.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_convex(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 1.5 * (x - elbow_point) + elbow_point
else:
y = 0.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_quadratic(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 0.1 * x**2
return [{"request_throughput": y}]
return wrapped
def _var2metric_sqrt(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 10 * x**0.5
return [{"request_throughput": y}]
return wrapped
def _run_solve_sla(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
criterion: SLACriterionBase,
init_value: int,
max_value: int,
min_value: int = 1,
max_value: int = 100,
):
with _set_return_value(var2metric):
return _estimate_sla_bounds(
result = solve_sla(
server=None,
bench_cmd=[],
serve_comb=ParameterSweepItem(),
......@@ -60,143 +112,129 @@ def _run_estimate_sla_bounds(
num_runs=1,
dry_run=False,
sla_variable="request_rate",
init_value=init_value,
max_value=max_value,
sla_min_value=min_value,
sla_max_value=max_value,
)
assert result is not None
return result
def test_estimate_sla_bounds_le():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
init_value=1,
max_value=100,
)
assert max_passing == 32
assert min_failing == 64
assert history.get_max_passing() == 32
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
2: True,
4: True,
8: True,
16: True,
32: True,
64: False,
33: False,
}
def test_estimate_sla_bounds_lt():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_lt():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThan(target=32),
init_value=1,
max_value=100,
)
assert max_passing == 16
assert min_failing == 32
assert history.get_max_passing() == 31
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
2: True,
4: True,
8: True,
16: True,
31: True,
32: False,
}
def test_estimate_sla_bounds_oob():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds(
_var2metric_identity,
def test_solve_linear_sla_oob():
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
init_value=64,
max_value=128,
min_value=64,
)
assert max_passing == 0
assert min_failing == 64
assert history.get_max_passing() == 64
assert history.get_min_failing() == 64
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
64: False,
}
def _run_test_find_sla_value_le(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
criterion: SLACriterionBase,
min_value: int,
max_value: int,
):
with _set_return_value(var2metric):
return _find_sla_value(
server=None,
bench_cmd=[],
serve_comb=ParameterSweepItem(),
bench_comb=ParameterSweepItem(),
sla_comb=SLASweepItem({"request_throughput": criterion}),
base_path=Path(""),
num_runs=1,
dry_run=False,
sla_variable="request_rate",
min_value=min_value,
max_value=max_value,
)
def test_solve_concave_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_concave(elbow_point=32),
SLALessThanOrEqualTo(target=24),
)
assert history.get_max_passing() == 16
def test_find_sla_value_le():
sla_data, sla_value, history = _run_test_find_sla_value_le(
_var2metric_identity,
SLALessThanOrEqualTo(target=50.0),
min_value=32,
max_value=64,
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
7: True,
13: True,
15: True,
16: True,
17: False,
}
def test_solve_convex_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_convex(elbow_point=32),
SLALessThanOrEqualTo(target=24),
)
assert sla_value == 50
assert history.get_max_passing() == 26
assert {val: margin <= 0 for val, margin in history.items()} == {
48: True,
56: False,
52: False,
50: True,
51: False,
100: False,
1: True,
48: False,
30: False,
24: True,
26: True,
27: False,
}
def test_find_sla_value_lt():
sla_data, sla_value, history = _run_test_find_sla_value_le(
_var2metric_identity,
SLALessThan(target=50.0),
min_value=32,
max_value=64,
def test_solve_quadratic_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_quadratic(y_intercept=10),
SLALessThanOrEqualTo(target=50),
)
assert sla_value == 49
assert history.get_max_passing() == 20
assert {val: margin <= 0 for val, margin in history.items()} == {
48: True,
56: False,
52: False,
50: False,
49: True,
100: False,
1: True,
4: True,
20: True,
21: False,
}
def test_find_sla_value_oob():
sla_data, sla_value, history = _run_test_find_sla_value_le(
_var2metric_identity,
SLALessThanOrEqualTo(target=50.0),
min_value=64,
max_value=128,
def test_solve_sqrt_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_sqrt(y_intercept=10),
SLALessThanOrEqualTo(target=100),
)
assert sla_value == 64
assert history.get_max_passing() == 81
assert {val: margin <= 0 for val, margin in history.items()} == {
96: False,
80: False,
72: False,
68: False,
66: False,
65: False,
64: False,
100: False,
1: True,
89: False,
81: True,
82: False,
}
......@@ -3,14 +3,11 @@
import argparse
import contextlib
import json
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
......@@ -24,6 +21,15 @@ try:
except ImportError:
pd = PlaceholderModule("pandas")
try:
from scipy.interpolate import PchipInterpolator
except ImportError:
PchipInterpolator = (
PlaceholderModule("scipy")
.placeholder_attr("interpolate")
.placeholder_attr("PchipInterpolator")
)
def _get_sla_base_path(
output_dir: Path,
......@@ -118,89 +124,36 @@ def run_sla(
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
class SLAHistory(dict[int, float]):
def __init__(self, min_value: int, max_value: int) -> None:
super().__init__()
assert_never(sla_variable)
self.min_value = min_value
self.max_value = max_value
def get_xy(self) -> tuple[list[int], list[float]]:
xs = list[int]()
ys = list[float]()
for x, y in sorted(self.items()):
xs.append(x)
ys.append(y)
def _estimate_sla_bounds(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
init_value: int,
max_value: int,
):
sla_data = list[dict[str, object]]()
val: int = init_value
assert val > 0
history = dict[int, float]()
while True:
print(f"Testing {sla_variable}: {val} req/s")
return xs, ys
iter_data = run_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: val},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
num_runs=num_runs,
dry_run=dry_run,
def get_max_passing(self) -> float:
return max(
(val for val, margin in self.items() if margin <= 0),
default=self.min_value,
)
assert iter_data is not None
sla_data.extend(iter_data)
iter_data_mean = {
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
for k in sla_comb
}
sla_margins = [
criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
margin = max(sla_margins)
history[val] = margin
if margin <= 0:
print("SLA criteria are met.")
val *= 2
else:
print("SLA criteria are not met.")
break
if val >= max_value:
break
max_passing = max(
(val for val, margin in history.items() if margin <= 0),
default=0,
)
min_failing = min(
(val for val, margin in history.items() if margin > 0),
default=max_value,
)
return sla_data, (max_passing, min_failing), history
def get_min_failing(self) -> float:
return min(
(val for val, margin in self.items() if margin > 0),
default=self.max_value,
)
def _find_sla_value(
def solve_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
......@@ -211,18 +164,33 @@ def _find_sla_value(
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
min_value: int,
max_value: int,
sla_min_value: int = 1,
sla_max_value: int = 8192, # The value that represents infinite QPS
):
sla_data = list[dict[str, object]]()
left: int = min_value
right: int = max_value
history = dict[int, float]()
while True:
val = (left + right) // 2
history = SLAHistory(min_value=sla_min_value, max_value=sla_max_value)
# NOTE: We don't use equality here to be more robust against noisy results
while history.get_max_passing() + 1 < history.get_min_failing():
if len(history) == 0:
val = sla_max_value
elif len(history) == 1:
val = sla_min_value
else:
spl = PchipInterpolator(*history.get_xy(), extrapolate=False)
spl_roots = spl.solve()
if len(spl_roots) == 0:
# Fallback to binary search
val = int((history.get_max_passing() + history.get_min_failing()) / 2)
else:
val = int(spl_roots[0])
if val in history:
# Cover both sides (floor and ceil) of the root to be sure
# that it is indeed the target value
val += 1
val = max(sla_min_value, min(val, sla_max_value))
print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla(
......@@ -234,8 +202,9 @@ def _find_sla_value(
num_runs=num_runs,
dry_run=dry_run,
)
if iter_data is None:
return None
assert iter_data is not None
sla_data.extend(iter_data)
iter_data_mean = {
......@@ -247,20 +216,14 @@ def _find_sla_value(
criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
margin = max(sla_margins)
history[val] = margin
history[val] = margin = max(sla_margins)
if margin <= 0:
print("SLA criteria are met.")
left = val
print(f"SLA criteria are met. ({margin=:.2f})")
else:
print("SLA criteria are not met.")
right = val
if right - left <= 1 and left in history:
break
print(f"SLA criteria are not met. ({margin=:.2f})")
return sla_data, left, history
return sla_data, history
def search_sla(
......@@ -271,7 +234,6 @@ def search_sla(
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
sla_variable: SLAVariable,
sla_inf_value: int = 65536, # The value that represents infinite QPS
base_path: Path,
num_runs: int,
dry_run: bool,
......@@ -279,43 +241,7 @@ def search_sla(
print("[SLA START]")
print(f"SLA criteria: {sla_comb.as_text()}")
sla_data_0 = run_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: sla_inf_value},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
num_runs=num_runs,
dry_run=dry_run,
)
if sla_data_0 is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return None
sla_init_value = math.ceil(
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
/ len(sla_data_0)
)
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
init_value=sla_init_value,
max_value=sla_inf_value,
)
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
sla_data_2, sla_value, _ = _find_sla_value(
result = solve_sla(
server,
bench_cmd,
serve_comb=serve_comb,
......@@ -325,11 +251,15 @@ def search_sla(
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
min_value=sla_min,
max_value=sla_max,
)
if result is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return
sla_data = sla_data_0 + sla_data_1 + sla_data_2
sla_data, sla_history = result
sla_value = sla_history.get_max_passing()
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
with _get_sla_iter_path(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment