Commit ef96fa3f (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Benchmark][2/2] Use spline interpolation to tune SLA variables (#32095)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2a4dbe24
...@@ -129,10 +129,10 @@ vllm bench sweep serve_sla \ ...@@ -129,10 +129,10 @@ vllm bench sweep serve_sla \
The algorithm for adjusting the SLA variable is as follows: The algorithm for adjusting the SLA variable is as follows:
1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable. 1. Run the benchmark once with maximum possible QPS, and once with minimum possible QPS. For each run, calculate the distance of the SLA metrics from their targets, resulting in data points of QPS vs SLA distance.
- For example, the initial request rate is set to the concurrency under infinite QPS. 2. Perform spline interpolation between the data points to estimate the QPS that results in zero SLA distance.
2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied. 3. Run the benchmark with the estimated QPS and add the resulting data point to the history.
3. Apply binary search over the window to find the maximum value that still satisfies the SLA. 4. Repeat Steps 2 and 3 until the maximum QPS that passes SLA and the minimum QPS that fails SLA in the history are close enough to each other.
!!! important !!! important
SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`. SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
......
...@@ -5,7 +5,7 @@ from pathlib import Path ...@@ -5,7 +5,7 @@ from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value from vllm.benchmarks.sweep.serve_sla import solve_sla
from vllm.benchmarks.sweep.server import ServerProcess from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import ( from vllm.benchmarks.sweep.sla_sweep import (
SLACriterionBase, SLACriterionBase,
...@@ -39,18 +39,70 @@ def _set_return_value( ...@@ -39,18 +39,70 @@ def _set_return_value(
return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla) return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
def _var2metric_identity(bench_comb): def _var2metric_linear():
return [{"request_throughput": float(bench_comb["request_rate"])}] def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = x
return [{"request_throughput": y}]
def _run_estimate_sla_bounds( return wrapped
def _var2metric_concave(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 0.5 * (x - elbow_point) + elbow_point
else:
y = 1.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_convex(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 1.5 * (x - elbow_point) + elbow_point
else:
y = 0.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_quadratic(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 0.1 * x**2
return [{"request_throughput": y}]
return wrapped
def _var2metric_sqrt(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 10 * x**0.5
return [{"request_throughput": y}]
return wrapped
def _run_solve_sla(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
criterion: SLACriterionBase, criterion: SLACriterionBase,
init_value: int, min_value: int = 1,
max_value: int, max_value: int = 100,
): ):
with _set_return_value(var2metric): with _set_return_value(var2metric):
return _estimate_sla_bounds( result = solve_sla(
server=None, server=None,
bench_cmd=[], bench_cmd=[],
serve_comb=ParameterSweepItem(), serve_comb=ParameterSweepItem(),
...@@ -60,143 +112,129 @@ def _run_estimate_sla_bounds( ...@@ -60,143 +112,129 @@ def _run_estimate_sla_bounds(
num_runs=1, num_runs=1,
dry_run=False, dry_run=False,
sla_variable="request_rate", sla_variable="request_rate",
init_value=init_value, sla_min_value=min_value,
max_value=max_value, sla_max_value=max_value,
) )
assert result is not None
return result
def test_estimate_sla_bounds_le():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( def test_solve_linear_sla_le():
_var2metric_identity, sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32), SLALessThanOrEqualTo(target=32),
init_value=1,
max_value=100,
) )
assert max_passing == 32 assert history.get_max_passing() == 32
assert min_failing == 64
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True, 1: True,
2: True,
4: True,
8: True,
16: True,
32: True, 32: True,
64: False, 33: False,
} }
def test_estimate_sla_bounds_lt(): def test_solve_linear_sla_lt():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( sla_data, history = _run_solve_sla(
_var2metric_identity, _var2metric_linear(),
SLALessThan(target=32), SLALessThan(target=32),
init_value=1,
max_value=100,
) )
assert max_passing == 16 assert history.get_max_passing() == 31
assert min_failing == 32
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True, 1: True,
2: True, 31: True,
4: True,
8: True,
16: True,
32: False, 32: False,
} }
def test_estimate_sla_bounds_oob(): def test_solve_linear_sla_oob():
sla_data, (max_passing, min_failing), history = _run_estimate_sla_bounds( sla_data, history = _run_solve_sla(
_var2metric_identity, _var2metric_linear(),
SLALessThanOrEqualTo(target=32), SLALessThanOrEqualTo(target=32),
init_value=64, min_value=64,
max_value=128,
) )
assert max_passing == 0 assert history.get_max_passing() == 64
assert min_failing == 64 assert history.get_min_failing() == 64
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
64: False, 64: False,
} }
def _run_test_find_sla_value_le( def test_solve_concave_sla_le():
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]], sla_data, history = _run_solve_sla(
criterion: SLACriterionBase, _var2metric_concave(elbow_point=32),
min_value: int, SLALessThanOrEqualTo(target=24),
max_value: int, )
):
with _set_return_value(var2metric):
return _find_sla_value(
server=None,
bench_cmd=[],
serve_comb=ParameterSweepItem(),
bench_comb=ParameterSweepItem(),
sla_comb=SLASweepItem({"request_throughput": criterion}),
base_path=Path(""),
num_runs=1,
dry_run=False,
sla_variable="request_rate",
min_value=min_value,
max_value=max_value,
)
assert history.get_max_passing() == 16
def test_find_sla_value_le(): assert {val: margin <= 0 for val, margin in history.items()} == {
sla_data, sla_value, history = _run_test_find_sla_value_le( 100: False,
_var2metric_identity, 1: True,
SLALessThanOrEqualTo(target=50.0), 7: True,
min_value=32, 13: True,
max_value=64, 15: True,
16: True,
17: False,
}
def test_solve_convex_sla_le():
sla_data, history = _run_solve_sla(
_var2metric_convex(elbow_point=32),
SLALessThanOrEqualTo(target=24),
) )
assert sla_value == 50 assert history.get_max_passing() == 26
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
48: True, 100: False,
56: False, 1: True,
52: False, 48: False,
50: True, 30: False,
51: False, 24: True,
26: True,
27: False,
} }
def test_find_sla_value_lt(): def test_solve_quadratic_sla_le():
sla_data, sla_value, history = _run_test_find_sla_value_le( sla_data, history = _run_solve_sla(
_var2metric_identity, _var2metric_quadratic(y_intercept=10),
SLALessThan(target=50.0), SLALessThanOrEqualTo(target=50),
min_value=32,
max_value=64,
) )
assert sla_value == 49 assert history.get_max_passing() == 20
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
48: True, 100: False,
56: False, 1: True,
52: False, 4: True,
50: False, 20: True,
49: True, 21: False,
} }
def test_find_sla_value_oob(): def test_solve_sqrt_sla_le():
sla_data, sla_value, history = _run_test_find_sla_value_le( sla_data, history = _run_solve_sla(
_var2metric_identity, _var2metric_sqrt(y_intercept=10),
SLALessThanOrEqualTo(target=50.0), SLALessThanOrEqualTo(target=100),
min_value=64,
max_value=128,
) )
assert sla_value == 64 assert history.get_max_passing() == 81
assert {val: margin <= 0 for val, margin in history.items()} == { assert {val: margin <= 0 for val, margin in history.items()} == {
96: False, 100: False,
80: False, 1: True,
72: False, 89: False,
68: False, 81: True,
66: False, 82: False,
65: False,
64: False,
} }
...@@ -3,14 +3,11 @@ ...@@ -3,14 +3,11 @@
import argparse import argparse
import contextlib import contextlib
import json import json
import math
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import ClassVar, Literal, get_args from typing import ClassVar, Literal, get_args
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem from .param_sweep import ParameterSweep, ParameterSweepItem
...@@ -24,6 +21,15 @@ try: ...@@ -24,6 +21,15 @@ try:
except ImportError: except ImportError:
pd = PlaceholderModule("pandas") pd = PlaceholderModule("pandas")
try:
from scipy.interpolate import PchipInterpolator
except ImportError:
PchipInterpolator = (
PlaceholderModule("scipy")
.placeholder_attr("interpolate")
.placeholder_attr("PchipInterpolator")
)
def _get_sla_base_path( def _get_sla_base_path(
output_dir: Path, output_dir: Path,
...@@ -118,89 +124,36 @@ def run_sla( ...@@ -118,89 +124,36 @@ def run_sla(
SLAVariable = Literal["request_rate", "max_concurrency"] SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable): class SLAHistory(dict[int, float]):
request_throughput = float(run_data["request_throughput"]) # type: ignore def __init__(self, min_value: int, max_value: int) -> None:
if sla_variable == "request_rate": super().__init__()
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable) self.min_value = min_value
self.max_value = max_value
def get_xy(self) -> tuple[list[int], list[float]]:
xs = list[int]()
ys = list[float]()
for x, y in sorted(self.items()):
xs.append(x)
ys.append(y)
def _estimate_sla_bounds( return xs, ys
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
init_value: int,
max_value: int,
):
sla_data = list[dict[str, object]]()
val: int = init_value
assert val > 0
history = dict[int, float]()
while True:
print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla( def get_max_passing(self) -> float:
server, return max(
bench_cmd, (val for val, margin in self.items() if margin <= 0),
serve_comb=serve_comb, default=self.min_value,
bench_comb=bench_comb | {sla_variable: val},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
num_runs=num_runs,
dry_run=dry_run,
) )
assert iter_data is not None def get_min_failing(self) -> float:
sla_data.extend(iter_data) return min(
(val for val, margin in self.items() if margin > 0),
iter_data_mean = { default=self.max_value,
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore )
for k in sla_comb
}
sla_margins = [
criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
margin = max(sla_margins)
history[val] = margin
if margin <= 0:
print("SLA criteria are met.")
val *= 2
else:
print("SLA criteria are not met.")
break
if val >= max_value:
break
max_passing = max(
(val for val, margin in history.items() if margin <= 0),
default=0,
)
min_failing = min(
(val for val, margin in history.items() if margin > 0),
default=max_value,
)
return sla_data, (max_passing, min_failing), history
def _find_sla_value( def solve_sla(
server: ServerProcess | None, server: ServerProcess | None,
bench_cmd: list[str], bench_cmd: list[str],
*, *,
...@@ -211,18 +164,33 @@ def _find_sla_value( ...@@ -211,18 +164,33 @@ def _find_sla_value(
num_runs: int, num_runs: int,
dry_run: bool, dry_run: bool,
sla_variable: SLAVariable, sla_variable: SLAVariable,
min_value: int, sla_min_value: int = 1,
max_value: int, sla_max_value: int = 8192, # The value that represents infinite QPS
): ):
sla_data = list[dict[str, object]]() sla_data = list[dict[str, object]]()
history = SLAHistory(min_value=sla_min_value, max_value=sla_max_value)
left: int = min_value
right: int = max_value # NOTE: We don't use equality here to be more robust against noisy results
while history.get_max_passing() + 1 < history.get_min_failing():
history = dict[int, float]() if len(history) == 0:
val = sla_max_value
while True: elif len(history) == 1:
val = (left + right) // 2 val = sla_min_value
else:
spl = PchipInterpolator(*history.get_xy(), extrapolate=False)
spl_roots = spl.solve()
if len(spl_roots) == 0:
# Fallback to binary search
val = int((history.get_max_passing() + history.get_min_failing()) / 2)
else:
val = int(spl_roots[0])
if val in history:
# Cover both sides (floor and ceil) of the root to be sure
# that it is indeed the target value
val += 1
val = max(sla_min_value, min(val, sla_max_value))
print(f"Testing {sla_variable}: {val} req/s") print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla( iter_data = run_sla(
...@@ -234,8 +202,9 @@ def _find_sla_value( ...@@ -234,8 +202,9 @@ def _find_sla_value(
num_runs=num_runs, num_runs=num_runs,
dry_run=dry_run, dry_run=dry_run,
) )
if iter_data is None:
return None
assert iter_data is not None
sla_data.extend(iter_data) sla_data.extend(iter_data)
iter_data_mean = { iter_data_mean = {
...@@ -247,20 +216,14 @@ def _find_sla_value( ...@@ -247,20 +216,14 @@ def _find_sla_value(
criterion.print_and_compute_margin(iter_data_mean, k) criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items() for k, criterion in sla_comb.items()
] ]
margin = max(sla_margins) history[val] = margin = max(sla_margins)
history[val] = margin
if margin <= 0: if margin <= 0:
print("SLA criteria are met.") print(f"SLA criteria are met. ({margin=:.2f})")
left = val
else: else:
print("SLA criteria are not met.") print(f"SLA criteria are not met. ({margin=:.2f})")
right = val
if right - left <= 1 and left in history:
break
return sla_data, left, history return sla_data, history
def search_sla( def search_sla(
...@@ -271,7 +234,6 @@ def search_sla( ...@@ -271,7 +234,6 @@ def search_sla(
bench_comb: ParameterSweepItem, bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem, sla_comb: SLASweepItem,
sla_variable: SLAVariable, sla_variable: SLAVariable,
sla_inf_value: int = 65536, # The value that represents infinite QPS
base_path: Path, base_path: Path,
num_runs: int, num_runs: int,
dry_run: bool, dry_run: bool,
...@@ -279,43 +241,7 @@ def search_sla( ...@@ -279,43 +241,7 @@ def search_sla(
print("[SLA START]") print("[SLA START]")
print(f"SLA criteria: {sla_comb.as_text()}") print(f"SLA criteria: {sla_comb.as_text()}")
sla_data_0 = run_sla( result = solve_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: sla_inf_value},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
num_runs=num_runs,
dry_run=dry_run,
)
if sla_data_0 is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return None
sla_init_value = math.ceil(
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
/ len(sla_data_0)
)
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
init_value=sla_init_value,
max_value=sla_inf_value,
)
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
sla_data_2, sla_value, _ = _find_sla_value(
server, server,
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
...@@ -325,11 +251,15 @@ def search_sla( ...@@ -325,11 +251,15 @@ def search_sla(
num_runs=num_runs, num_runs=num_runs,
dry_run=dry_run, dry_run=dry_run,
sla_variable=sla_variable, sla_variable=sla_variable,
min_value=sla_min,
max_value=sla_max,
) )
if result is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return
sla_data = sla_data_0 + sla_data_1 + sla_data_2 sla_data, sla_history = result
sla_value = sla_history.get_max_passing()
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.") print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
with _get_sla_iter_path( with _get_sla_iter_path(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment