# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from pathlib import Path
from unittest.mock import patch

from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import solve_sla
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
    SLACriterionBase,
    SLALessThan,
    SLALessThanOrEqualTo,
    SLASweepItem,
)


def _set_return_value(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
):
    """
    Build a patcher that replaces `run_sla` in the `serve_sla` module
    with a stub.

    The stub accepts the same arguments as the call it replaces but only
    looks at `bench_comb` (the benchmark combination, which carries the
    SLA variable) and returns `var2metric(bench_comb)` as the metrics.
    """

    def _fake_run_sla(
        server: ServerProcess | None,
        bench_cmd: list[str],
        *,
        serve_comb: ParameterSweepItem,
        bench_comb: ParameterSweepItem,
        iter_path: Path,
        num_runs: int,
        dry_run: bool,
    ):
        # Every argument other than the benchmark combination is ignored.
        return var2metric(bench_comb)

    target = "vllm.benchmarks.sweep.serve_sla.run_sla"
    return patch(target, side_effect=_fake_run_sla)


42
43
44
45
def _var2metric_linear():
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = x
46

47
        return [{"request_throughput": y}]
48

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
    return wrapped


def _var2metric_concave(elbow_point: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        if x < elbow_point:
            y = 0.5 * (x - elbow_point) + elbow_point
        else:
            y = 1.5 * (x - elbow_point) + elbow_point

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_convex(elbow_point: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        if x < elbow_point:
            y = 1.5 * (x - elbow_point) + elbow_point
        else:
            y = 0.5 * (x - elbow_point) + elbow_point

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_quadratic(y_intercept: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = y_intercept + 0.1 * x**2

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_sqrt(y_intercept: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = y_intercept + 10 * x**0.5

        return [{"request_throughput": y}]

    return wrapped


def _run_solve_sla(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
    criterion: SLACriterionBase,
    min_value: int = 1,
    max_value: int = 100,
):
    """
    Run `solve_sla` against a synthetic metric function.

    `run_sla` is patched (via `_set_return_value`) so that no server or
    benchmark is launched; the metrics come from `var2metric` instead.
    The SLA is a single criterion on `request_throughput`, searched over
    the `request_rate` variable between `min_value` and `max_value`.

    Args:
        var2metric: Maps a benchmark combination to the list of per-run
            metric dicts that the patched `run_sla` should return.
        criterion: SLA criterion applied to `request_throughput`.
        min_value: Passed to `solve_sla` as `sla_min_value`.
        max_value: Passed to `solve_sla` as `sla_max_value`.

    Returns:
        Whatever `solve_sla` returns (asserted to be non-None); callers in
        this file unpack it as `(sla_data, history)`.
    """
    with _set_return_value(var2metric):
        result = solve_sla(
            server=None,
            bench_cmd=[],
            serve_comb=ParameterSweepItem(),
            bench_comb=ParameterSweepItem(),
            sla_comb=SLASweepItem({"request_throughput": criterion}),
            base_path=Path(""),
            num_runs=1,
            dry_run=False,
            sla_variable="request_rate",
            sla_min_value=min_value,
            sla_max_value=max_value,
        )
        assert result is not None

        return result
121

122
123
124
125

def test_solve_linear_sla_le():
    """
    Linear metric (throughput == request rate) with SLA `<= 32`.

    The largest passing request rate is expected to be exactly 32, and the
    search history is expected to hold exactly the values listed below.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=32),
    )

    assert history.get_max_passing() == 32

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        32: True,
        33: False,
    }


139
140
141
def test_solve_linear_sla_lt():
    """
    Linear metric (throughput == request rate) with strict SLA `< 32`.

    Because the bound is strict, 32 itself must fail and the largest
    passing request rate is expected to be 31.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThan(target=32),
    )

    assert history.get_max_passing() == 31

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        31: True,
        32: False,
    }


155
156
157
def test_solve_linear_sla_oob():
    """
    Linear metric with SLA `<= 32` but a search range starting at 64.

    Every value in the range violates the SLA (throughput == request
    rate >= 64 > 32). The test pins the solver's reported bounds for this
    out-of-range case: both the max-passing and the min-failing value are
    expected to equal the lower bound, 64.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=32),
        min_value=64,
    )

    assert history.get_max_passing() == 64
    assert history.get_min_failing() == 64

    # Only the two range endpoints are expected in the history, and both
    # fail the SLA.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        64: False,
    }


171
172
173
174
175
def test_solve_concave_sla_le():
    """
    Piecewise-linear metric from `_var2metric_concave` with SLA `<= 24`.

    With the elbow at 32, the metric reaches 24 at a request rate of 16
    (0.5 * (16 - 32) + 32), so 16 is the expected largest passing value.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_concave(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
    )

    assert history.get_max_passing() == 16

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        7: True,
        13: True,
        15: True,
        16: True,
        17: False,
    }


def test_solve_convex_sla_le():
    """
    Piecewise-linear metric from `_var2metric_convex` with SLA `<= 24`.

    With the elbow at 32, the metric is 23.0 at a request rate of 26 and
    24.5 at 27, so 26 is the expected largest passing value.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_convex(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
    )

    assert history.get_max_passing() == 26

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        48: False,
        30: False,
        24: True,
        26: True,
        27: False,
    }


209
210
211
212
def test_solve_quadratic_sla_le():
    """
    Quadratic metric (10 + 0.1 * rate**2) with SLA `<= 50`.

    The metric equals 50 at a request rate of 20, so 20 is the expected
    largest passing value.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_quadratic(y_intercept=10),
        SLALessThanOrEqualTo(target=50),
    )

    assert history.get_max_passing() == 20

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        4: True,
        20: True,
        21: False,
    }


226
227
228
229
def test_solve_sqrt_sla_le():
    """
    Square-root metric (10 + 10 * rate**0.5) with SLA `<= 100`.

    The metric equals 100 at a request rate of 81, so 81 is the expected
    largest passing value.
    """
    sla_data, history = _run_solve_sla(
        _var2metric_sqrt(y_intercept=10),
        SLALessThanOrEqualTo(target=100),
    )

    assert history.get_max_passing() == 81

    # Each history entry is reduced to "margin <= 0", which the expected
    # values below treat as "this value satisfies the SLA".
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        89: False,
        81: True,
        82: False,
    }