"tests/vscode:/vscode.git/clone" did not exist on "9a234c7adca95bbdf2c1dc410c5617d421ac3c11"
Unverified commit 5f2385a4, authored by Cyrus Leung and committed by GitHub
Browse files

[Benchmark][1/2] Generalize SLA criterion validation from binary flags to margins (#32075)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a01a1c0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from pathlib import Path
from unittest.mock import patch
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _estimate_sla_bounds, _find_sla_value
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
SLACriterionBase,
SLALessThan,
SLALessThanOrEqualTo,
SLASweepItem,
)
def _set_return_value(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
):
    """
    Create a patch for `run_sla` with a specific function
    indicating the relationship between the benchmark combination
    (which includes the SLA variable) and the SLA criterion.

    Args:
        var2metric: Maps the benchmark combination handed to `run_sla`
            to the list of per-run metric dictionaries the mock returns.

    Returns:
        The `unittest.mock.patch` object targeting
        `vllm.benchmarks.sweep.serve_sla.run_sla`. It is not started here;
        callers activate it by using it as a context manager.
    """

    def mock_run_sla(
        server: ServerProcess | None,
        bench_cmd: list[str],
        *,
        serve_comb: ParameterSweepItem,
        bench_comb: ParameterSweepItem,
        iter_path: Path,
        num_runs: int,
        dry_run: bool,
    ):
        # Only the benchmark combination drives the mocked metrics;
        # every other argument is accepted and ignored.
        return var2metric(bench_comb)

    return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
def _var2metric_identity(bench_comb):
return [{"request_throughput": float(bench_comb["request_rate"])}]
def _run_estimate_sla_bounds(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
    criterion: SLACriterionBase,
    init_value: int,
    max_value: int,
):
    """
    Call `_estimate_sla_bounds` with `run_sla` replaced by a mock.

    Args:
        var2metric: Function producing the mocked benchmark metrics
            (see `_set_return_value`).
        criterion: SLA criterion applied to the `"request_throughput"` metric.
        init_value: First value of the `"request_rate"` SLA variable to try.
        max_value: Upper limit passed through to `_estimate_sla_bounds`.

    Returns:
        Whatever `_estimate_sla_bounds` returns, unchanged.
    """
    with _set_return_value(var2metric):
        return _estimate_sla_bounds(
            server=None,
            bench_cmd=[],
            serve_comb=ParameterSweepItem(),
            bench_comb=ParameterSweepItem(),
            sla_comb=SLASweepItem({"request_throughput": criterion}),
            base_path=Path(""),
            num_runs=1,
            dry_run=False,
            sla_variable="request_rate",
            init_value=init_value,
            max_value=max_value,
        )
def test_estimate_sla_bounds_le():
    """
    Bounds estimation with `<= 32`: 32 is the last passing value
    and 64 is the first failing one.
    """
    # The returned SLA data is not inspected by this test.
    _, (max_passing, min_failing), history = _run_estimate_sla_bounds(
        _var2metric_identity,
        SLALessThanOrEqualTo(target=32),
        init_value=1,
        max_value=100,
    )

    assert max_passing == 32
    assert min_failing == 64

    # A non-positive margin means the criterion was met at that value.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        1: True,
        2: True,
        4: True,
        8: True,
        16: True,
        32: True,
        64: False,
    }
def test_estimate_sla_bounds_lt():
    """
    Bounds estimation with strict `< 32`: the value 32 itself must fail,
    so 16 is the last passing value.
    """
    # The returned SLA data is not inspected by this test.
    _, (max_passing, min_failing), history = _run_estimate_sla_bounds(
        _var2metric_identity,
        SLALessThan(target=32),
        init_value=1,
        max_value=100,
    )

    assert max_passing == 16
    assert min_failing == 32

    # A non-positive margin means the criterion was met at that value.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        1: True,
        2: True,
        4: True,
        8: True,
        16: True,
        32: False,
    }
def test_estimate_sla_bounds_oob():
    """
    Bounds estimation when the initial value already violates the SLA:
    no passing value is found, so `max_passing` is reported as 0.
    """
    # The returned SLA data is not inspected by this test.
    _, (max_passing, min_failing), history = _run_estimate_sla_bounds(
        _var2metric_identity,
        SLALessThanOrEqualTo(target=32),
        init_value=64,
        max_value=128,
    )

    assert max_passing == 0
    assert min_failing == 64

    # Only the initial value is probed, and it fails (positive margin).
    assert {val: margin <= 0 for val, margin in history.items()} == {
        64: False,
    }
def _run_test_find_sla_value_le(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
    criterion: SLACriterionBase,
    min_value: int,
    max_value: int,
):
    """
    Call `_find_sla_value` with `run_sla` replaced by a mock.

    Despite the `_le` suffix, this helper works with any criterion;
    the strict less-than test in this module uses it as well.

    Args:
        var2metric: Function producing the mocked benchmark metrics
            (see `_set_return_value`).
        criterion: SLA criterion applied to the `"request_throughput"` metric.
        min_value: Lower end of the `"request_rate"` search range.
        max_value: Upper end of the `"request_rate"` search range.

    Returns:
        Whatever `_find_sla_value` returns, unchanged.
    """
    with _set_return_value(var2metric):
        return _find_sla_value(
            server=None,
            bench_cmd=[],
            serve_comb=ParameterSweepItem(),
            bench_comb=ParameterSweepItem(),
            sla_comb=SLASweepItem({"request_throughput": criterion}),
            base_path=Path(""),
            num_runs=1,
            dry_run=False,
            sla_variable="request_rate",
            min_value=min_value,
            max_value=max_value,
        )
def test_find_sla_value_le():
    """
    Searching [32, 64] with `<= 50` must settle on 50,
    the largest value that still meets the criterion.
    """
    # The returned SLA data is not inspected by this test.
    _, sla_value, history = _run_test_find_sla_value_le(
        _var2metric_identity,
        SLALessThanOrEqualTo(target=50.0),
        min_value=32,
        max_value=64,
    )

    assert sla_value == 50

    # Probed values and whether each met the criterion (margin <= 0).
    assert {val: margin <= 0 for val, margin in history.items()} == {
        48: True,
        56: False,
        52: False,
        50: True,
        51: False,
    }
def test_find_sla_value_lt():
    """
    Searching [32, 64] with strict `< 50` must settle on 49,
    because the value 50 itself does not meet the criterion.
    """
    # The returned SLA data is not inspected by this test.
    _, sla_value, history = _run_test_find_sla_value_le(
        _var2metric_identity,
        SLALessThan(target=50.0),
        min_value=32,
        max_value=64,
    )

    assert sla_value == 49

    # Probed values and whether each met the criterion (margin <= 0).
    assert {val: margin <= 0 for val, margin in history.items()} == {
        48: True,
        56: False,
        52: False,
        50: False,
        49: True,
    }
def test_find_sla_value_oob():
    """
    Searching [64, 128] with `<= 50`: every value in range fails,
    so the search returns the lower bound after probing it too.
    """
    # The returned SLA data is not inspected by this test.
    _, sla_value, history = _run_test_find_sla_value_le(
        _var2metric_identity,
        SLALessThanOrEqualTo(target=50.0),
        min_value=64,
        max_value=128,
    )

    assert sla_value == 64

    # The lower bound (64) must appear in the history: the search is
    # expected to evaluate it before returning it as the result.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        96: False,
        80: False,
        72: False,
        68: False,
        66: False,
        65: False,
        64: False,
    }
...@@ -74,7 +74,8 @@ class ParameterSweepItem(dict[str, object]): ...@@ -74,7 +74,8 @@ class ParameterSweepItem(dict[str, object]):
representation of all parameters. representation of all parameters.
""" """
if "_benchmark_name" in self: if "_benchmark_name" in self:
return self["_benchmark_name"] return str(self["_benchmark_name"])
return self.as_text(sep="-") return self.as_text(sep="-")
# In JSON, we prefer "_" # In JSON, we prefer "_"
......
...@@ -145,12 +145,11 @@ def _estimate_sla_bounds( ...@@ -145,12 +145,11 @@ def _estimate_sla_bounds(
): ):
sla_data = list[dict[str, object]]() sla_data = list[dict[str, object]]()
max_passing: int = 0
min_failing: int = 0
val: int = init_value val: int = init_value
assert val > 0 assert val > 0
history = dict[int, float]()
while True: while True:
print(f"Testing {sla_variable}: {val} req/s") print(f"Testing {sla_variable}: {val} req/s")
...@@ -172,24 +171,33 @@ def _estimate_sla_bounds( ...@@ -172,24 +171,33 @@ def _estimate_sla_bounds(
for k in sla_comb for k in sla_comb
} }
sla_results = [ sla_margins = [
criterion.print_and_validate(iter_data_mean, k) criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items() for k, criterion in sla_comb.items()
] ]
margin = max(sla_margins)
history[val] = margin
if all(sla_results): if margin <= 0:
print("SLA criteria are met.") print("SLA criteria are met.")
max_passing = val
val *= 2 val *= 2
else: else:
print("SLA criteria are not met.") print("SLA criteria are not met.")
min_failing = val
break break
if val >= max_value: if val >= max_value:
break break
return sla_data, (max_passing, min_failing) max_passing = max(
(val for val, margin in history.items() if margin <= 0),
default=0,
)
min_failing = min(
(val for val, margin in history.items() if margin > 0),
default=max_value,
)
return sla_data, (max_passing, min_failing), history
def _find_sla_value( def _find_sla_value(
...@@ -211,6 +219,8 @@ def _find_sla_value( ...@@ -211,6 +219,8 @@ def _find_sla_value(
left: int = min_value left: int = min_value
right: int = max_value right: int = max_value
history = dict[int, float]()
while True: while True:
val = (left + right) // 2 val = (left + right) // 2
print(f"Testing {sla_variable}: {val} req/s") print(f"Testing {sla_variable}: {val} req/s")
...@@ -233,22 +243,24 @@ def _find_sla_value( ...@@ -233,22 +243,24 @@ def _find_sla_value(
for k in sla_comb for k in sla_comb
} }
sla_results = [ sla_margins = [
criterion.print_and_validate(iter_data_mean, k) criterion.print_and_compute_margin(iter_data_mean, k)
for k, criterion in sla_comb.items() for k, criterion in sla_comb.items()
] ]
margin = max(sla_margins)
history[val] = margin
if all(sla_results): if margin <= 0:
print("SLA criteria are met.") print("SLA criteria are met.")
left = val left = val
else: else:
print("SLA criteria are not met.") print("SLA criteria are not met.")
right = val right = val
if right - left <= 1: if right - left <= 1 and left in history:
break break
return sla_data, left return sla_data, left, history
def search_sla( def search_sla(
...@@ -288,7 +300,7 @@ def search_sla( ...@@ -288,7 +300,7 @@ def search_sla(
) )
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.") print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds( sla_data_1, (sla_min, sla_max), _ = _estimate_sla_bounds(
server, server,
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
...@@ -303,7 +315,7 @@ def search_sla( ...@@ -303,7 +315,7 @@ def search_sla(
) )
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.") print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
sla_data_2, sla_value = _find_sla_value( sla_data_2, sla_value, _ = _find_sla_value(
server, server,
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
......
...@@ -7,39 +7,45 @@ from dataclasses import dataclass ...@@ -7,39 +7,45 @@ from dataclasses import dataclass
from typing_extensions import override from typing_extensions import override
SLA_EPS = 1e-8
"""Offset used to differentiate margins for equality checks."""
@dataclass @dataclass
class SLACriterionBase(ABC): class SLACriterionBase(ABC):
target: float target: float
@abstractmethod @abstractmethod
def validate(self, actual: float) -> bool: def compute_margin(self, actual: float) -> float:
"""Return `True` if this criterion is met; otherwise `False`.""" """
Return a negative value or `0` if this criterion is met;
otherwise a positive value indicating the distance to the target.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def format_cond(self, lhs: str) -> str: def format_cond(self, lhs: str) -> str:
raise NotImplementedError raise NotImplementedError
def print_and_validate( def print_and_compute_margin(
self, self,
metrics: dict[str, float], metrics: dict[str, float],
metrics_key: str, metrics_key: str,
) -> bool: ) -> float:
metric = metrics[metrics_key] metric = metrics[metrics_key]
result = self.validate(metric) margin = self.compute_margin(metric)
cond = self.format_cond(f"{metrics_key} = {metric:.2f}") cond = self.format_cond(f"{metrics_key} = {metric:.2f}")
print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED")) print(f"Validating SLA: {cond} | " + ("PASSED" if margin <= 0 else "FAILED"))
return result return margin
@dataclass @dataclass
class SLALessThan(SLACriterionBase): class SLALessThan(SLACriterionBase):
@override @override
def validate(self, actual: float) -> bool: def compute_margin(self, actual: float) -> float:
return actual < self.target return actual + SLA_EPS - self.target
@override @override
def format_cond(self, lhs: str) -> str: def format_cond(self, lhs: str) -> str:
...@@ -49,8 +55,8 @@ class SLALessThan(SLACriterionBase): ...@@ -49,8 +55,8 @@ class SLALessThan(SLACriterionBase):
@dataclass @dataclass
class SLALessThanOrEqualTo(SLACriterionBase): class SLALessThanOrEqualTo(SLACriterionBase):
@override @override
def validate(self, actual: float) -> bool: def compute_margin(self, actual: float) -> float:
return actual <= self.target return actual - self.target
@override @override
def format_cond(self, lhs: str) -> str: def format_cond(self, lhs: str) -> str:
...@@ -60,8 +66,8 @@ class SLALessThanOrEqualTo(SLACriterionBase): ...@@ -60,8 +66,8 @@ class SLALessThanOrEqualTo(SLACriterionBase):
@dataclass @dataclass
class SLAGreaterThan(SLACriterionBase): class SLAGreaterThan(SLACriterionBase):
@override @override
def validate(self, actual: float) -> bool: def compute_margin(self, actual: float) -> float:
return actual > self.target return self.target + SLA_EPS - actual
@override @override
def format_cond(self, lhs: str) -> str: def format_cond(self, lhs: str) -> str:
...@@ -71,8 +77,8 @@ class SLAGreaterThan(SLACriterionBase): ...@@ -71,8 +77,8 @@ class SLAGreaterThan(SLACriterionBase):
@dataclass @dataclass
class SLAGreaterThanOrEqualTo(SLACriterionBase): class SLAGreaterThanOrEqualTo(SLACriterionBase):
@override @override
def validate(self, actual: float) -> bool: def compute_margin(self, actual: float) -> float:
return actual >= self.target return self.target - actual
@override @override
def format_cond(self, lhs: str) -> str: def format_cond(self, lhs: str) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment