w8a8_benchmarks.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import argparse
import copy
import itertools
import pickle as pkl
import time
9
10
from collections.abc import Iterable
from typing import Callable, Optional
11
12
13
14

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
15
from utils import make_rand_tensors
16
17
18
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
19
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
20
21
    w8a8_block_fp8_matmul,
)
22
from vllm.utils import FlexibleArgumentParser, cdiv
23

24
# Default sweep parameters for the ``model_bench`` sub-command.
# By default every model with registered weight shapes is benchmarked.
DEFAULT_MODELS = [*WEIGHT_SHAPES]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


# bench
30
31
32
def bench_fn(
    label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
    """Time ``fn(*args, **kwargs)`` with ``torch.utils.benchmark``.

    Args:
        label: Group label shown by ``TBenchmark.Compare``.
        sub_label: Row label (callers in this file pass the GEMM shape).
        description: Column label (callers in this file pass the kernel name).
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The measurement produced by ``Timer.blocked_autorange``.
    """
    # Minimum total wall time, in seconds, that blocked_autorange spends
    # sampling.
    min_run_time = 1

    # Names visible to ``stmt`` inside the timer. Deliberately not called
    # ``globals`` so the builtin of that name is not shadowed.
    timer_globals = {
        "args": args,
        "kwargs": kwargs,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


49
def bench_int8(
    dtype: torch.dtype,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
    """Benchmark INT8-based kernels for one ``(m, k, n)`` GEMM shape.

    Args:
        dtype: Must be ``torch.int8``.
        m, k, n: GEMM problem size.
        label: Group label forwarded to ``bench_fn``.
        sub_label: Row label forwarded to ``bench_fn``.
        bench_kernels: Exact kernel names to run; ``None`` runs every kernel.

    Returns:
        One measurement per kernel that was run, in definition order.
    """
    assert dtype == torch.int8
    a, b = make_rand_tensors(torch.int8, m, n, k)
    # Scalar unit scales plus all-zero bias / zero-point tensors.
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
    azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
    azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)

    # Cast the inputs of the unscaled PyTorch baselines once, outside the
    # timed lambdas. Casting inside them would add dtype-conversion copies to
    # every measured call and inflate the baseline numbers. The trade-off is
    # holding these copies in memory for the duration of this call.
    a_bf16, b_bf16 = a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
    a_fp16, b_fp16 = a.to(dtype=torch.float16), b.to(dtype=torch.float16)

    bench_fns = {
        "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(a_bf16, b_bf16),
        "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(a_fp16, b_fp16),
        "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.bfloat16
        ),
        "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.bfloat16, bias
        ),
        "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
            a, b, scale_a, scale_b, torch.bfloat16, azp_adj
        ),
        "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
            a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
        ),
        "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
            a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
        ),
        "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
            a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
        ),
    }

    timers = []
    for name, fn in bench_fns.items():
        # If bench_kernels is None, run all. Otherwise, run only exact matches.
        if bench_kernels is None or name in bench_kernels:
            print(f"Running {name}")
            timers.append(bench_fn(label, sub_label, name, fn))

    return timers


104
def bench_fp8(
    dtype: torch.dtype,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
    """Benchmark FP8-based kernels for one ``(m, k, n)`` GEMM shape.

    Args:
        dtype: Must be ``torch.float8_e4m3fn``.
        m, k, n: GEMM problem size.
        label: Group label forwarded to ``bench_fn``.
        sub_label: Row label forwarded to ``bench_fn``.
        bench_kernels: Exact kernel names to run; ``None`` runs every kernel.

    Returns:
        One measurement per kernel that was run, in definition order.
    """
    assert dtype == torch.float8_e4m3fn
    a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
    a_cont = a.contiguous()

    # Scalar (per-tensor) unit scales.
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    # Block-wise scales for a (128, 128) block size. a gets one scale per
    # 128-wide chunk of K for each of its m rows; b gets one scale per
    # 128x128 (K, N) tile.
    block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
    block_scale_b = torch.rand(
        cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
    )
    # Same values, re-laid out so the first dimension is contiguous in memory
    # (column-major). Presumably the layout the CUTLASS blockwise path
    # expects -- TODO confirm.
    block_scale_a_M_major = block_scale_a.t().contiguous().t()
    block_scale_b_K_major = block_scale_b.t().contiguous().t()
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Do every dtype conversion here, outside the timed lambdas. Casting
    # inside them would add conversion copies to each measured call. That
    # inflates the PyTorch baselines and the CUTLASS fp16-bias row.
    a_bf16, b_bf16 = a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
    a_fp16, b_fp16 = a.to(dtype=torch.float16), b.to(dtype=torch.float16)
    bias_fp16 = bias.to(dtype=torch.float16)

    bench_fns = {
        "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(a_bf16, b_bf16),
        "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(a_fp16, b_fp16),
        "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
            a, b, scale_a, scale_b, out_dtype=torch.float16
        ),
        "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
            a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
        ),
        "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
            a, b, scale_a, scale_b, out_dtype=torch.bfloat16
        ),
        "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
            a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
        ),
        "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.bfloat16
        ),
        "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.float16
        ),
        "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.bfloat16, bias
        ),
        "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
            a, b, scale_a, scale_b, torch.float16, bias_fp16
        ),
        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
            a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
        ),
        "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
            a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
        ),
    }

    timers = []
    for name, fn in bench_fns.items():
        # If bench_kernels is None, run all. Otherwise, run only exact matches.
        if bench_kernels is None or name in bench_kernels:
            print(f"Running {name}")
            timers.append(bench_fn(label, sub_label, name, fn))

    return timers


179
180
181
182
183
184
185
186
187
def bench(
    dtype: torch.dtype,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
    """Run the benchmark suite that matches ``dtype`` for one GEMM shape.

    Raises:
        ValueError: if ``dtype`` is neither ``torch.int8`` nor
            ``torch.float8_e4m3fn``.
    """
    # Checked in order with ``==``, exactly like an if-chain would.
    suites = (
        (torch.int8, bench_int8),
        (torch.float8_e4m3fn, bench_fp8),
    )
    for suite_dtype, suite in suites:
        if dtype == suite_dtype:
            return suite(dtype, m, k, n, label, sub_label, bench_kernels)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print ``timers`` as a ``torch.utils.benchmark.Compare`` table."""
    TBenchmark.Compare(timers).print()


201
202
203
204
205
def run(
    dtype: torch.dtype,
    MKNs: Iterable[tuple[int, int, int]],
    bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
    """Benchmark every ``(m, k, n)`` shape in ``MKNs`` for ``dtype``.

    The comparison table for each shape is printed as soon as that shape
    finishes. All measurements are also returned together, in order.
    """
    group_label = f"scaled-{dtype}-gemm"
    measurements = []
    for shape in MKNs:
        m, k, n = shape
        shape_timers = bench(
            dtype,
            m,
            k,
            n,
            group_label,
            f"MKN=({m}x{k}x{n})",
            bench_kernels=bench_kernels,
        )
        print_timers(shape_timers)
        measurements.extend(shape_timers)
    return measurements


222
223
224
225
226
227
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print the combined comparison table and pickle the raw measurements.

    Args:
        data: Measurements to report. Any iterable is accepted; it is
            materialized first because it is consumed twice.
        MKNs: Shapes that produced ``data``. Not used by this function; kept
            for call-site compatibility.
        base_description: Title of the printed table and prefix of the
            output file name.
        timestamp: File-name suffix; defaults to ``int(time.time())``.

    The pickle is written to ``f"{base_description}-{timestamp}.pkl"`` relative
    to the current working directory and holds a list of measurements.
    """
    # print_timers iterates ``data`` and so does pkl.dump. A one-shot iterator
    # (e.g. a generator) would be exhausted by the first and unpicklable by
    # the second, so take a concrete list up front.
    data = list(data)

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


def run_square_bench(args):
    """Benchmark square GEMMs (m == k == n) over an inclusive size range.

    Sizes go from ``args.dim_start`` up to and including ``args.dim_end`` in
    steps of ``args.dim_increment``.
    """
    MKNs = [
        (dim, dim, dim)
        for dim in range(args.dim_start, args.dim_end + 1, args.dim_increment)
    ]
    data = run(args.dtype, MKNs, bench_kernels=args.kernels)
    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Sweep GEMM dimensions over a range, optionally pinning some of them.

    Every dimension whose ``--{m,k,n}-constant`` flag is unset takes the
    values ``args.dim_start, args.dim_start + args.dim_increment, ...`` up to
    and including ``args.dim_end``. Dimensions with a constant are held fixed
    at that value for every step.
    """
    # Inclusive end, matching run_square_bench. The same --dim-start /
    # --dim-end / --dim-increment flags then sweep the same sizes in both
    # sub-commands.
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs, bench_kernels=args.kernels)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark the weight-GEMM shapes of each requested model and TP size.

    For every ``(model, tp_size)`` pair, each ``(k, n)`` weight shape is
    combined with every batch size in ``args.batch_sizes`` as ``m``. After
    everything has run, one table per pair is printed. All measurements are
    then pickled to a single ``model_bench-<dtype>-<timestamp>.pkl`` file.
    """
    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}]  {name}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        # Divide each weight's tensor-parallel split dimension by the TP
        # degree. Entries are deep-copied so WEIGHT_SHAPES is left untouched.
        shapes = []
        for kn, split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            kn[split_dim] //= tp_size
            shapes.append(kn)
        return shapes

    models_tps = list(itertools.product(args.models, args.tp_sizes))
    model_bench_data = []
    for model, tp_size in models_tps:
        batch_sizes = args.batch_sizes
        kns = model_shapes(model, tp_size)
        MKNs = [(m, k, n) for m in batch_sizes for k, n in kns]
        model_bench_data.append(run(args.dtype, MKNs, bench_kernels=args.kernels))

    # Print all results
    for (model, tp_size), data in zip(models_tps, model_bench_data):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())
    all_data = list(itertools.chain.from_iterable(model_bench_data))
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


296
if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map the ``--dtype`` CLI string to a torch dtype."""
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        # argparse reports a ValueError raised by a ``type=`` callable as a
        # usage error for the offending option.
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
    
    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
    
    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
    
    Output:
        - a .pkl file, that is a list of raw torch.utils.benchmark.Measurement objects for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['int8', 'fp8']",
    )
    parser.add_argument(
        "--kernels",
        nargs="+",
        type=str,
        default=None,
        help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
    )

    # Without required=True, omitting the sub-command leaves ``args.func``
    # unset and the script dies with an AttributeError instead of a usage
    # message.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)