serve_workload.py 10.3 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
4
import math
5
6
7
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
8
from typing import ClassVar, Literal, get_args
9

10
11
12
import numpy as np
from typing_extensions import assert_never

13
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
14
15
from vllm.utils.import_utils import PlaceholderModule

16
from .param_sweep import ParameterSweep, ParameterSweepItem
17
18
19
20
21
22
from .serve import (
    SweepServeArgs,
    _get_comb_base_path,
    run_comb,
    server_ctx,
)
23
24
from .server import ServerProcess

25
26
27
28
29
try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")

30

31
# Name of the benchmark parameter that is varied between workload levels.
# The allowed values are also used as the choices of `--workload-var`.
WorkloadVariable = Literal["request_rate", "max_concurrency"]
32
33


34
35
36
37
def _estimate_workload_value(
    run_data: dict[str, object],
    workload_var: WorkloadVariable,
) -> float:
    """Estimate the workload level reached by a single benchmark run.

    Args:
        run_data: Result record of one benchmark run. It must contain
            `request_throughput`, and also `mean_e2el_ms` when
            `workload_var` is `"max_concurrency"`.
        workload_var: The workload variable in which to express the estimate.

    Returns:
        For `"request_rate"`, the measured request throughput.
        For `"max_concurrency"`, the request throughput multiplied by the
        mean latency converted from milliseconds to seconds.

    Raises:
        KeyError: If a required metric is missing from `run_data`.
    """
    request_throughput = float(run_data["request_throughput"])  # type: ignore

    if workload_var == "request_rate":
        return request_throughput

    if workload_var == "max_concurrency":
        mean_latency_ms = float(run_data["mean_e2el_ms"])  # type: ignore
        # Average number of requests in flight = arrival rate x time in system
        # (Little's law); presumably the throughput is in requests per second.
        return request_throughput * mean_latency_ms / 1000

    # Exhaustiveness check: flags any `WorkloadVariable` member not handled above
    assert_never(workload_var)
46
47


48
49
50
51
52
53
def _estimate_workload_avg(
    runs: list[dict[str, object]],
    workload_var: WorkloadVariable,
) -> float:
    """Average the estimated workload level over several benchmark runs.

    Args:
        runs: Result records, one per benchmark run.
        workload_var: The workload variable in which to express the estimate.

    Returns:
        The arithmetic mean of `_estimate_workload_value` over `runs`.

    Raises:
        ValueError: If `runs` is empty, since no average can be computed.
    """
    if not runs:
        # Fail with a descriptive error instead of a bare ZeroDivisionError
        raise ValueError(
            f"Cannot estimate `{workload_var}` from an empty list of runs"
        )

    total = sum(_estimate_workload_value(run, workload_var) for run in runs)
    return total / len(runs)
54
55


56
def run_comb_workload(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
    workload_var: WorkloadVariable,
    workload_value: int,
) -> list[dict[str, object]] | None:
    """Run one benchmark combination with the workload variable pinned.

    Args:
        server: Forwarded to `run_comb`; may be `None`.
        bench_cmd: The benchmark command line, forwarded to `run_comb`.
        serve_comb: Serve parameters of this combination.
        bench_comb: Bench parameters of this combination.
        output_dir: Directory under which the result path is derived.
        num_runs: Forwarded to `run_comb`.
        dry_run: Forwarded to `run_comb`.
        link_vars: Forwarded to `run_comb`.
        workload_var: The bench parameter to override.
        workload_value: The value assigned to `workload_var`.

    Returns:
        Whatever `run_comb` returns: a list of result records, or `None`.
    """
    # The override takes precedence over any existing `workload_var` entry
    bench_comb_workload = bench_comb | {workload_var: workload_value}

    return run_comb(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb_workload,
        # The path is derived from the `bench_comb` that was passed in (before
        # the override above), with the workload level appended as extra parts.
        base_path=_get_comb_base_path(
            output_dir,
            serve_comb,
            bench_comb,
            extra_parts=("WL-", f"{workload_var}={workload_value}"),
        ),
        num_runs=num_runs,
        dry_run=dry_run,
        link_vars=link_vars,
    )
86
87


88
def _resolve_dataset_size(
    bench_cmd: list[str],
    bench_comb: ParameterSweepItem,
) -> int:
    """Look up the number of prompts configured for the benchmark.

    The value is taken from, in order of precedence: the `num_prompts` entry
    of `bench_comb`; the first usable `--num-prompts VALUE` or
    `--num-prompts=VALUE` argument in `bench_cmd`; `DEFAULT_NUM_PROMPTS`.
    """
    if "num_prompts" in bench_comb:
        return int(bench_comb["num_prompts"])  # type: ignore

    for i, arg in enumerate(bench_cmd):
        # A trailing `--num-prompts` without a value is ignored
        if arg == "--num-prompts" and i + 1 < len(bench_cmd):
            return int(bench_cmd[i + 1])
        if arg.startswith("--num-prompts="):
            return int(arg.split("=", 1)[1])

    return DEFAULT_NUM_PROMPTS


def explore_comb_workloads(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    workload_var: WorkloadVariable,
    workload_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
) -> list[dict[str, object]] | None:
    """Explore several workload levels for one serve/bench combination.

    Two boundary benchmarks are run first: one with `max_concurrency=1`
    ("serial") and one with `max_concurrency` equal to the dataset size
    ("batch"). The workload level estimated from each of them bounds the range
    from which the intermediate values of `workload_var` are interpolated.

    Args:
        server: Forwarded to `run_comb_workload`; may be `None`.
        bench_cmd: The benchmark command line. It is also scanned for
            `--num-prompts` to determine the dataset size.
        serve_comb: Serve parameters of this combination.
        bench_comb: Bench parameters of this combination.
        workload_var: The bench parameter to explore.
        workload_iters: Total number of workload levels, including the two
            boundary runs. Must be at least 2.
        output_dir: Directory under which results are stored.
        num_runs: Forwarded to `run_comb_workload`.
        dry_run: Forwarded to `run_comb_workload`.
        link_vars: Forwarded to `run_comb_workload`.

    Returns:
        The result records of the serial, intermediate and batch runs (in that
        order), or `None` if either boundary run produced no data.

    Raises:
        ValueError: If `workload_iters` is less than 2.
    """
    print("[WL START]")
    print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
    print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
    print(f"Number of workload iterations: {workload_iters}")

    if workload_iters < 2:
        raise ValueError("`workload_iters` should be at least 2")

    dataset_size = _resolve_dataset_size(bench_cmd, bench_comb)
    print(f"Dataset size: {dataset_size}")

    def run_at(comb: ParameterSweepItem, value: int):
        # All runs share everything except the bench parameters and the level
        return run_comb_workload(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=comb,
            output_dir=output_dir,
            num_runs=num_runs,
            dry_run=dry_run,
            link_vars=link_vars,
            workload_var=workload_var,
            workload_value=value,
        )

    serial_workload_data = run_at(bench_comb | {"max_concurrency": 1}, 1)
    batch_workload_data = run_at(
        bench_comb | {"max_concurrency": dataset_size},
        dataset_size,
    )

    if serial_workload_data is None or batch_workload_data is None:
        if dry_run:
            print("Omitting intermediate Workload iterations.")
            print("[WL END]")

        return None

    # Round the bounds inwards (ceil for the lower, floor for the upper)
    serial_workload_value = math.ceil(
        _estimate_workload_avg(serial_workload_data, workload_var)
    )
    print(f"Serial inference: {workload_var}={serial_workload_value}")

    batch_workload_value = math.floor(
        _estimate_workload_avg(batch_workload_data, workload_var)
    )
    print(f"Batch inference: {workload_var}={batch_workload_value}")

    # Avoid duplicated runs for intermediate values if the range between
    # `serial_workload_value` and `batch_workload_value` is small
    interpolated = np.linspace(
        serial_workload_value, batch_workload_value, workload_iters
    )[1:-1]  # drop both endpoints
    inter_workload_values = sorted({round(value) for value in interpolated})

    inter_workloads_data: list[dict[str, object]] = []
    for inter_workload_value in inter_workload_values:
        print(f"Exploring: {workload_var}={inter_workload_value}")
        inter_workload_data = run_at(bench_comb, inter_workload_value)
        if inter_workload_data is not None:
            inter_workloads_data.extend(inter_workload_data)

    print("[WL END]")

    return serial_workload_data + inter_workloads_data + batch_workload_data
193
194


195
def explore_combs_workloads(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    workload_var: WorkloadVariable,
    workload_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Explore workload levels for every serve/bench parameter combination.

    For each item of `serve_params` a `server_ctx` context is entered, and
    inside it `explore_comb_workloads` is called once per item of
    `bench_params`. All collected result records are written to
    `summary.csv` in `output_dir`.

    Args:
        serve_cmd: Forwarded to `server_ctx`.
        bench_cmd: The benchmark command line.
        after_bench_cmd: Forwarded to `server_ctx`.
        show_stdout: Forwarded to `server_ctx`.
        server_ready_timeout: Forwarded to `server_ctx`.
        serve_params: Serve parameter combinations to iterate over.
        bench_params: Bench parameter combinations to iterate over. None of
            them may set `workload_var`.
        workload_var: The bench parameter to explore.
        workload_iters: Number of workload levels per combination.
        output_dir: Directory in which results and the summary are stored.
        num_runs: Forwarded to `explore_comb_workloads`.
        dry_run: If true, nothing is summarized and `None` is returned.
        link_vars: Forwarded to `explore_comb_workloads`.

    Returns:
        A pandas DataFrame built from all result records, or `None` when
        `dry_run` is set.

    Raises:
        ValueError: If any item of `bench_params` already sets `workload_var`.
    """
    if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
        raise ValueError(
            f"You should not override `{workload_var}` in `bench_params` "
            "since it is supposed to be explored automatically."
        )

    all_data: list[dict[str, object]] = []
    for serve_comb in serve_params:
        # One context per serve combination, shared by all bench combinations
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            server_ready_timeout=server_ready_timeout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
        ) as server:
            for bench_comb in bench_params:
                comb_data = explore_comb_workloads(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    workload_var=workload_var,
                    workload_iters=workload_iters,
                    output_dir=output_dir,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )

                if comb_data is not None:
                    all_data.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df


@dataclass
class SweepServeWorkloadArgs(SweepServeArgs):
    """Arguments of the `serve_workload` sweep.

    Extends `SweepServeArgs` with the workload variable to explore and the
    number of workload levels to run.
    """

    # The bench parameter that is adjusted in each iteration
    workload_var: WorkloadVariable
    # Number of workload levels to explore, including the first two iterations
    workload_iters: int

    parser_name: ClassVar[str] = "serve_workload"
    parser_help: ClassVar[str] = (
        "Explore the latency-throughput tradeoff for different workload levels."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed `argparse.Namespace`.

        The base fields come from `SweepServeArgs.from_cli_args`; the
        workload fields are read from `args.workload_var` and
        `args.workload_iters`.
        """
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`
        base_args = SweepServeArgs.from_cli_args(args)

        return cls(
            **asdict(base_args),
            workload_var=args.workload_var,
            workload_iters=args.workload_iters,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base and the workload options on `parser`.

        Returns:
            The parser returned by the base class's `add_cli_args`, with the
            "workload options" argument group added.
        """
        parser = super().add_cli_args(parser)

        workload_group = parser.add_argument_group("workload options")
        workload_group.add_argument(
            "--workload-var",
            type=str,
            choices=get_args(WorkloadVariable),
            default="request_rate",
            help="The variable to adjust in each iteration.",
        )
        workload_group.add_argument(
            "--workload-iters",
            type=int,
            default=10,
            help="Number of workload levels to explore. "
            "This includes the first two iterations used to interpolate the value of "
            "`workload_var` for remaining iterations.",
        )

        return parser


300
def run_main(args: SweepServeWorkloadArgs):
    """Run the workload exploration described by `args`.

    Results are stored in a sub-directory of `args.output_dir` named after
    `args.resume` if it is set, otherwise after the current timestamp.

    Args:
        args: The arguments of the sweep.

    Returns:
        The return value of `explore_combs_workloads`.

    Raises:
        ValueError: If `args.resume` is set but its directory does not exist.
        RuntimeError: If the exploration is interrupted for any reason; the
            original exception is chained as the cause.
    """
    timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = args.output_dir / timestamp

    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        return explore_combs_workloads(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            server_ready_timeout=args.server_ready_timeout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            workload_var=args.workload_var,
            workload_iters=args.workload_iters,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
            link_vars=args.link_vars,
        )
    except BaseException as exc:
        # NOTE(review): `BaseException` also covers `KeyboardInterrupt`,
        # presumably so that the resume hint is shown on Ctrl-C as well.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            "to continue the script from its last checkpoint."
        ) from exc


def main(args: argparse.Namespace):
    """Convert the parsed CLI namespace into sweep arguments and run them."""
    run_main(SweepServeWorkloadArgs.from_cli_args(args))
332
333
334


if __name__ == "__main__":
    # Script entry point: build a parser carrying the sweep's options,
    # then hand the parsed arguments to `main`.
    parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
    SweepServeWorkloadArgs.add_cli_args(parser)

    main(parser.parse_args())