# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import itertools
import pickle as pkl
import time
9
from collections.abc import Callable, Iterable
10
11
12
13
14
15
16
17

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
18
from vllm.utils.argparse_utils import FlexibleArgumentParser
19
20
21
22
23
24
25

# Default sweep values for the model_bench subcommand's CLI options.
DEFAULT_MODELS = [*WEIGHT_SHAPES]  # every model name registered in WEIGHT_SHAPES
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


# bench
def bench_fn(
    label: str,
    sub_label: str,
    description: str,
    fn: Callable,
    *args,
    min_run_time: float = 1,
    **kwargs,
) -> TMeasurement:
    """Time ``fn(*args, **kwargs)`` with ``torch.utils.benchmark``.

    Args:
        label: Table title used when measurements are compared/printed.
        sub_label: Row label for this measurement.
        description: Column label (the implementation being timed).
        fn: Callable under test.
        *args: Positional arguments forwarded to ``fn``.
        min_run_time: Minimum run time in seconds handed to
            ``Timer.blocked_autorange``. Keyword-only, defaults to 1 (the
            value this function always used). It is consumed here and is
            NOT forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The ``Measurement`` produced by ``blocked_autorange``.
    """
    # Named ``timer_globals`` to avoid shadowing the ``globals`` builtin.
    timer_globals = {
        "args": args,
        "kwargs": kwargs,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_int8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark int8 dense and sparse GEMMs for a single M x K x N shape.

    Before timing, the sparse CUTLASS output is compared against the dense
    CUTLASS output for the same inputs and "Correct results" or
    "Incorrect results" is printed.

    Args:
        dtype: Must be ``torch.int8``.
        m, k, n: GEMM problem shape.
        label: Table title forwarded to ``bench_fn``.
        sub_label: Row label forwarded to ``bench_fn``.

    Returns:
        A list with one ``Measurement`` per implementation timed.
    """
    assert dtype == torch.int8
    # NOTE(review): the shape is passed as (m, n, k) here — confirm this
    # matches make_rand_sparse_tensors' parameter order.
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    # Scalar (0-d) scales of 1.0.
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    # All-zero bias of length n, used only by the *_bias variants below.
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: sparse kernel vs. dense kernel on the same inputs.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    timers = []
    # pytorch impl - bfloat16. The .to() conversions are evaluated before
    # bench_fn is called, so they are not part of the timed statement.
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_bf16_bf16_bf16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.bfloat16),
            b.to(dtype=torch.bfloat16),
        )
    )

    # pytorch impl - float16
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_fp16_fp16_fp16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.float16),
            b.to(dtype=torch.float16),
        )
    )

    # cutlass dense impl
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_i8_i8_bf16_scaled_mm",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
        )
    )

    # cutlass dense with bias
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_i8_i8_bf16_scaled_mm_bias",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        )
    )

    # cutlass sparse impl
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_i8_i8_bf16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
        )
    )

    # cutlass sparse with bias
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        )
    )

    return timers


def bench_fp8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark fp8 (e4m3fn) dense and sparse GEMMs for one M x K x N shape.

    Before timing, the sparse CUTLASS output is compared against the dense
    CUTLASS output for the same inputs and "Correct results" or
    "Incorrect results" is printed. The timed variants are:
    a bf16 ``torch.mm`` baseline, ``torch._scaled_mm`` with and without
    ``use_fast_accum`` (bf16 and fp16 outputs), the dense CUTLASS kernel, and
    the sparse CUTLASS kernel with and without bias.

    Args:
        dtype: Must be ``torch.float8_e4m3fn``.
        m, k, n: GEMM problem shape.
        label: Table title forwarded to ``bench_fn``.
        sub_label: Row label forwarded to ``bench_fn``.

    Returns:
        A list with one ``Measurement`` per implementation timed.
    """
    assert dtype == torch.float8_e4m3fn
    # NOTE(review): the shape is passed as (m, n, k) here — confirm this
    # matches make_rand_sparse_tensors' parameter order.
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    # Scalar (0-d) scales of 1.0.
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    # All-zero bias of length n, used only by the *_bias variants below.
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: sparse kernel vs. dense kernel on the same inputs.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    # NOTE(review): torch.allclose uses its default rtol/atol here; if the two
    # kernels accumulate in a different order, small floating-point
    # differences may be reported as "Incorrect results" — consider explicit
    # tolerances.
    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    timers = []

    # pytorch impl w. bf16. The .to() conversions are evaluated before
    # bench_fn is called, so they are not part of the timed statement.
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_bf16_bf16_bf16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.bfloat16, device="cuda"),
            b.to(dtype=torch.bfloat16, device="cuda"),
        )
    )

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_fp8_fp8_bf16_scaled_mm",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.bfloat16,
        )
    )

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )
    )

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_fp8_fp8_fp16_scaled_mm",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.float16,
        )
    )

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(
            label,
            sub_label,
            "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.float16,
            use_fast_accum=True,
        )
    )

    # cutlass dense impl: bf16 output
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_fp8_fp8_bf16_scaled_mm",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
        )
    )

    # cutlass sparse impl: bf16 output
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
        )
    )

    # cutlass sparse impl: fp16 output
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.float16,
        )
    )

    # cutlass sparse impl: bf16 output, with bias
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        )
    )

    # cutlass sparse impl: fp16 output, with bias (bias cast to match output)
    timers.append(
        bench_fn(
            label,
            sub_label,
            "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.float16,
            bias.to(dtype=torch.float16),
        )
    )

    return timers


339
340
341
def bench(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
342
343
344
345
346
347
348
349
350
351
352
353
354
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Render ``timers`` as one torch.utils.benchmark comparison table."""
    table = TBenchmark.Compare(timers)
    table.print()


def run(
    dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
    """Benchmark every (M, K, N) shape in ``MKNs`` for ``dtype``.

    Each shape's measurements are printed as a comparison table as soon as
    they are produced.

    Returns:
        A flat list of all measurements, in the order the shapes were given.
    """
    # The table title depends only on dtype, so compute it once.
    label = f"scaled-{dtype}-gemm"
    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, label, f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements and pickle them to a file.

    The file is named ``<base_description>-<timestamp>.pkl`` and is written
    to the current working directory.

    Args:
        data: Measurements to report. Materialized into a list first so that
            a one-shot iterable is not exhausted by printing before pickling.
        MKNs: Shapes that were benchmarked. Not used by this function.
        base_description: Used in the header line and as the filename prefix.
        timestamp: Filename suffix; defaults to the current Unix time in whole
            seconds.
    """
    data = list(data)
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (M == K == N) over a range of sizes.

    Sizes run from ``args.dim_start`` up to and including ``args.dim_end``,
    stepping by ``args.dim_increment``. Results are printed and pickled via
    ``make_output``.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = [(size, size, size) for size in dim_sizes]
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Sweep GEMM sizes from ``--dim-start`` to ``--dim-end`` (inclusive).

    Each of M, K and N follows the swept size unless it is pinned to a fixed
    value with ``--m-constant``, ``--k-constant`` or ``--n-constant``.
    Results are printed and pickled via ``make_output``.
    """
    # ``+ 1`` makes --dim-end inclusive, matching run_square_bench, so the
    # same --dim-end value means the same thing for both subcommands.
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    num_sizes = len(dim_sizes)

    def sweep_or_pin(constant):
        # A user-supplied constant pins that dimension for every sweep step.
        return dim_sizes if constant is None else [constant] * num_sizes

    Ms = sweep_or_pin(args.m_constant)
    Ks = sweep_or_pin(args.k_constant)
    Ns = sweep_or_pin(args.n_constant)
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark the GEMM shapes of each requested model and TP size.

    For every (model, tp_size) pair, each (K, N) weight shape from
    ``WEIGHT_SHAPES`` has its TP-split dimension integer-divided by
    ``tp_size``. The shape is then combined with every batch size in
    ``args.batch_sizes`` as M. Results are printed per pair and pickled
    together to ``model_bench-<dtype>-<timestamp>.pkl``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        # Deep-copy so the in-place TP split below does not mutate the shared
        # WEIGHT_SHAPES table across (model, tp_size) iterations.
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        KNs = model_shapes(model, tp_size)
        MKNs = [(m, k, n) for m in args.batch_sizes for k, n in KNs]
        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, (model, tp_size) in zip(model_bench_data, models_tps):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    # pickle all data as one flat list
    all_data = list(itertools.chain.from_iterable(model_bench_data))
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map the --dtype CLI string to a torch dtype (argparse ``type=`` hook)."""
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
    
    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
    
    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
    
    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['int8', 'fp8']",
    )

    # required=True: without it, omitting the subcommand leaves ``args.func``
    # unset and ``args.func(args)`` below fails with AttributeError instead
    # of argparse printing a usage error.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)