# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Callable
from pathlib import Path
from unittest.mock import patch

from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _get_sla_run_path, solve_sla
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
    SLACriterionBase,
    SLALessThan,
    SLALessThanOrEqualTo,
    SLASweepItem,
)


def _set_return_value(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
):
    """
    Create a patch for run_sla with a specific function
    indicating the relationship between the benchmark combination
    (which includes the SLA variable) and the SLA criterion.

    The returned ``unittest.mock`` patcher replaces
    ``vllm.benchmarks.sweep.serve_sla.run_sla`` with a fake. The fake does not
    launch any benchmark. It evaluates ``var2metric(bench_comb)``, writes the
    result as indented JSON to ``_get_sla_run_path(iter_path, run_number=None)``
    and returns it.
    """

    def mock_run_sla(
        server: ServerProcess | None,
        bench_cmd: list[str],
        *,
        serve_comb: ParameterSweepItem,
        bench_comb: ParameterSweepItem,
        iter_path: Path,
        num_runs: int,
        dry_run: bool,
    ):
        # Only bench_comb feeds the synthetic metric. The other parameters
        # exist so that this fake accepts the same arguments as run_sla.
        iter_data = var2metric(bench_comb)

        # Write the summary file where the real run is expected to put it.
        # Presumably solve_sla reads it back when reusing history (see
        # test_solve_reuse_history); confirm against serve_sla.
        summary_path = _get_sla_run_path(iter_path, run_number=None)
        summary_path.parent.mkdir(parents=True, exist_ok=True)
        with summary_path.open("w") as f:
            json.dump(iter_data, f, indent=4)

        return iter_data

    return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)


def _var2metric_linear():
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = x
54

55
        return [{"request_throughput": y}]
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    return wrapped


def _var2metric_concave(elbow_point: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        if x < elbow_point:
            y = 0.5 * (x - elbow_point) + elbow_point
        else:
            y = 1.5 * (x - elbow_point) + elbow_point

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_convex(elbow_point: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        if x < elbow_point:
            y = 1.5 * (x - elbow_point) + elbow_point
        else:
            y = 0.5 * (x - elbow_point) + elbow_point

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_quadratic(y_intercept: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = y_intercept + 0.1 * x**2

        return [{"request_throughput": y}]

    return wrapped


def _var2metric_sqrt(y_intercept: float):
    def wrapped(bench_comb):
        x = float(bench_comb["request_rate"])
        y = y_intercept + 10 * x**0.5

        return [{"request_throughput": y}]

    return wrapped


def _run_solve_sla(
    var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
    criterion: SLACriterionBase,
    base_path: Path,
    min_value: int = 1,
    max_value: int = 100,
):
    """
    Run ``solve_sla`` against a synthetic metric and return its result.

    ``run_sla`` is patched through ``_set_return_value``, so no server or
    benchmark command is involved (``server=None``, ``bench_cmd=[]``).
    ``request_rate`` is the SLA variable, and the single criterion applies to
    ``request_throughput``.

    Args:
        var2metric: Maps a benchmark combination to the metrics it reports.
        criterion: SLA criterion applied to ``request_throughput``.
        base_path: Output directory passed through to ``solve_sla``.
        min_value: Passed to ``solve_sla`` as ``sla_min_value``.
        max_value: Passed to ``solve_sla`` as ``sla_max_value``.

    Returns:
        The non-``None`` result of ``solve_sla``. Tests in this module unpack
        it as ``(sla_data, history)``.
    """
    with _set_return_value(var2metric):
        result = solve_sla(
            server=None,
            bench_cmd=[],
            serve_comb=ParameterSweepItem(),
            bench_comb=ParameterSweepItem(),
            sla_comb=SLASweepItem({"request_throughput": criterion}),
            base_path=base_path,
            num_runs=1,
            dry_run=False,
            sla_variable="request_rate",
            sla_min_value=min_value,
            sla_max_value=max_value,
        )

    assert result is not None, "solve_sla returned None"

    return result


def test_solve_linear_sla_le(tmp_path):
    # Throughput equals request_rate, so "throughput <= 32" holds up to and
    # including a request rate of 32.
    _, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=32),
        tmp_path,
    )

    assert history.get_max_passing() == 32

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        32: True,
        33: False,
    }


def test_solve_linear_sla_lt(tmp_path):
    # With a strict "< 32" criterion the boundary value 32 fails, so the
    # largest passing rate is expected to be 31.
    _, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThan(target=32),
        tmp_path,
    )

    assert history.get_max_passing() == 31

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        31: True,
        32: False,
    }


def test_solve_linear_sla_oob(tmp_path):
    # The lower bound (64) is already above the target of 32, so every rate
    # probed by the search violates the SLA.
    _, history = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=32),
        tmp_path,
        min_value=64,
    )

    # NOTE(review): get_max_passing() is expected to report 64 even though 64
    # itself fails. This pins the current behavior when nothing in range
    # passes; confirm it is intended.
    assert history.get_max_passing() == 64
    assert history.get_min_failing() == 64

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        64: False,
    }


def test_solve_concave_sla_le(tmp_path):
    # Below the elbow at 32, throughput is 0.5 * (x - 32) + 32, which equals
    # the target of 24 at x = 16.
    _, history = _run_solve_sla(
        _var2metric_concave(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
        tmp_path,
    )

    assert history.get_max_passing() == 16

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        7: True,
        13: True,
        15: True,
        16: True,
        17: False,
    }


def test_solve_convex_sla_le(tmp_path):
    # Below the elbow at 32, throughput is 1.5 * (x - 32) + 32, which equals
    # the target of 24 at x ~= 26.67, so 26 is the largest passing rate.
    _, history = _run_solve_sla(
        _var2metric_convex(elbow_point=32),
        SLALessThanOrEqualTo(target=24),
        tmp_path,
    )

    assert history.get_max_passing() == 26

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        48: False,
        30: False,
        24: True,
        26: True,
        27: False,
    }


def test_solve_quadratic_sla_le(tmp_path):
    # 10 + 0.1 * x**2 <= 50 holds for x up to 20.
    _, history = _run_solve_sla(
        _var2metric_quadratic(y_intercept=10),
        SLALessThanOrEqualTo(target=50),
        tmp_path,
    )

    assert history.get_max_passing() == 20

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        4: True,
        20: True,
        21: False,
    }


def test_solve_sqrt_sla_le(tmp_path):
    # 10 + 10 * sqrt(x) <= 100 holds for x up to 81.
    _, history = _run_solve_sla(
        _var2metric_sqrt(y_intercept=10),
        SLALessThanOrEqualTo(target=100),
        tmp_path,
    )

    assert history.get_max_passing() == 81

    # Maps each request rate in the history to whether its margin is <= 0.
    assert {val: margin <= 0 for val, margin in history.items()} == {
        100: False,
        1: True,
        89: False,
        81: True,
        82: False,
    }
298


def test_solve_reuse_history(tmp_path):
    """
    A second solve sharing ``tmp_path`` also reports the first run's points.

    The second history contains the request rates from the first run, with
    their margins recomputed against the new target, alongside the rates
    probed by the second run.
    """
    _, first = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=10),
        tmp_path,
        min_value=1,
        max_value=20,
    )

    assert first.get_max_passing() == 10

    first_passed = {val: margin <= 0 for val, margin in first.items()}
    assert first_passed == {20: False, 1: True, 10: True, 11: False}

    _, second = _run_solve_sla(
        _var2metric_linear(),
        SLALessThanOrEqualTo(target=30),
        tmp_path,
        min_value=21,
        max_value=40,
    )

    assert second.get_max_passing() == 30

    second_passed = {val: margin <= 0 for val, margin in second.items()}
    expected_second = {
        # Carried over from the first run. Their pass/fail status differs
        # because the margins are measured against the new target.
        20: True,
        1: True,
        10: True,
        11: True,
        # Probed by this run.
        40: False,
        30: True,
        31: False,
    }
    assert second_passed == expected_second