"vscode:/vscode.git/clone" did not exist on "61a97c32f64641738d2cc623708f28046768224e"
sparse_benchmarks.py 13.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import argparse
import copy
import itertools
import pickle as pkl
import time
9
10
from collections.abc import Iterable
from typing import Callable
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser

# Default for `model_bench --models`: every model listed in WEIGHT_SHAPES.
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
# Default for `model_bench --batch-sizes`; each value is used as the GEMM M.
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
# Default for `model_bench --tp-sizes`; each weight shape's TP-split
# dimension is integer-divided by the TP size.
DEFAULT_TP_SIZES = [1]


# bench
27
28
29
def bench_fn(
    label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
    """Time ``fn(*args, **kwargs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        label: Measurement label (``TBenchmark.Compare`` groups tables by it).
        sub_label: Measurement sub-label (a row of the comparison table).
        description: Measurement description (a column of the table), used
            here to name the implementation being timed.
        fn: The callable to time.
        *args, **kwargs: Forwarded unchanged to ``fn`` on every invocation.

    Returns:
        The measurement from ``Timer.blocked_autorange`` run for at least
        ``min_run_time`` (1) second.
    """
    min_run_time = 1

    # Names visible to the timed statement. Deliberately not called
    # ``globals`` so the builtin of that name is not shadowed.
    timer_globals = {
        "args": args,
        "kwargs": kwargs,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


46
47
48
def bench_int8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark int8 dense vs. sparse GEMM implementations for one shape.

    Before timing, the sparse CUTLASS kernel (run on ``b_compressed`` plus
    ``e``, presumably the compressed B and its sparsity metadata) is compared
    against the dense CUTLASS kernel on the uncompressed ``b``. Then
    "Correct results" or "Incorrect results" (followed by both outputs) is
    printed. Both scales are 1.0 and the bias is all zeros.

    Args:
        dtype: Must be ``torch.int8``.
        m, k, n: GEMM dimensions.
        label, sub_label: Attached to every measurement via ``bench_fn``.

    Returns:
        Measurements in this order: PyTorch ``torch.mm`` in bf16 and fp16
        (no scales); CUTLASS dense ``cutlass_scaled_mm`` without and with
        bias; CUTLASS ``cutlass_scaled_sparse_mm`` without and with bias.
        All CUTLASS variants produce bf16 output.
    """
    assert dtype == torch.int8
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: sparse kernel on the compressed operand vs. dense kernel
    # on the original operand.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    def timed(description: str, fn: Callable, *args, **kwargs) -> TMeasurement:
        # Every measurement from this call shares the same label/sub_label.
        return bench_fn(label, sub_label, description, fn, *args, **kwargs)

    # List elements are evaluated left to right, so the benchmarks (and the
    # dtype conversions feeding them) run in exactly the listed order.
    return [
        # pytorch impl - bfloat16
        timed(
            "pytorch_bf16_bf16_bf16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.bfloat16),
            b.to(dtype=torch.bfloat16),
        ),
        # pytorch impl - float16
        timed(
            "pytorch_fp16_fp16_fp16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.float16),
            b.to(dtype=torch.float16),
        ),
        # cutlass impl
        timed(
            "cutlass_i8_i8_bf16_scaled_mm",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
        ),
        # cutlass with bias
        timed(
            "cutlass_i8_i8_bf16_scaled_mm_bias",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        ),
        # cutlass sparse impl
        timed(
            "cutlass_i8_i8_bf16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
        ),
        # cutlass sparse with bias
        timed(
            "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        ),
    ]


159
160
161
def bench_fp8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark fp8 (``torch.float8_e4m3fn``) dense vs. sparse GEMMs for one shape.

    Before timing, the sparse CUTLASS kernel (run on ``b_compressed`` plus
    ``e``, presumably the compressed B and its sparsity metadata) is compared
    against the dense CUTLASS kernel on the uncompressed ``b``. Then
    "Correct results" or "Incorrect results" (followed by both outputs) is
    printed. Both scales are 1.0 and the bias is all zeros.

    Args:
        dtype: Must be ``torch.float8_e4m3fn``.
        m, k, n: GEMM dimensions.
        label, sub_label: Attached to every measurement via ``bench_fn``.

    Returns:
        Measurements in this order:

        - PyTorch bf16 ``torch.mm`` (no scales).
        - ``torch._scaled_mm`` with bf16 then fp16 output, each without and
          then with ``use_fast_accum``.
        - CUTLASS dense ``cutlass_scaled_mm`` (bf16 output).
        - CUTLASS ``cutlass_scaled_sparse_mm`` with bf16 then fp16 output.
        - The same sparse kernel with bias, for bf16 then fp16 output.
    """
    assert dtype == torch.float8_e4m3fn
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: sparse kernel on the compressed operand vs. dense kernel
    # on the original operand.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    def timed(description: str, fn: Callable, *args, **kwargs) -> TMeasurement:
        # Every measurement from this call shares the same label/sub_label.
        return bench_fn(label, sub_label, description, fn, *args, **kwargs)

    # List elements are evaluated left to right, so the benchmarks (and the
    # dtype conversions feeding them) run in exactly the listed order.
    return [
        # pytorch impl w. bf16
        timed(
            "pytorch_bf16_bf16_bf16_matmul-no-scales",
            torch.mm,
            a.to(dtype=torch.bfloat16, device="cuda"),
            b.to(dtype=torch.bfloat16, device="cuda"),
        ),
        # pytorch impl: bf16 output, without fp8 fast accum
        timed(
            "pytorch_fp8_fp8_bf16_scaled_mm",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.bfloat16,
        ),
        # pytorch impl: bf16 output, with fp8 fast accum
        timed(
            "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        ),
        # pytorch impl: fp16 output, without fp8 fast accum
        timed(
            "pytorch_fp8_fp8_fp16_scaled_mm",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.float16,
        ),
        # pytorch impl: fp16 output, with fp8 fast accum
        timed(
            "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
            torch._scaled_mm,
            a,
            b,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.float16,
            use_fast_accum=True,
        ),
        # cutlass dense impl: bf16 output
        timed(
            "cutlass_fp8_fp8_bf16_scaled_mm",
            ops.cutlass_scaled_mm,
            a,
            b,
            scale_a,
            scale_b,
            torch.bfloat16,
        ),
        # cutlass sparse impl: bf16 output
        timed(
            "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
        ),
        # cutlass sparse impl: fp16 output
        timed(
            "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.float16,
        ),
        # cutlass sparse impl: bf16 output, with bias
        timed(
            "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.bfloat16,
            bias,
        ),
        # cutlass sparse impl: fp16 output, with bias (bias cast to fp16
        # alongside the fp16 output dtype)
        timed(
            "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
            ops.cutlass_scaled_sparse_mm,
            a,
            b_compressed,
            e,
            scale_a,
            scale_b,
            torch.float16,
            bias.to(dtype=torch.float16),
        ),
    ]


340
341
342
def bench(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
343
344
345
346
347
348
349
350
351
352
353
354
355
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print ``timers`` to stdout as a ``torch.utils.benchmark`` comparison table."""
    TBenchmark.Compare(timers).print()


356
357
358
def run(
    dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
    """Benchmark every (m, k, n) shape in ``MKNs`` for ``dtype``.

    Each shape's measurements are printed as a table as soon as that shape
    finishes. All measurements are also collected into one flat list.

    Args:
        dtype: Element type passed to ``bench``.
        MKNs: GEMM shapes to benchmark, in order.

    Returns:
        All measurements, in shape order.
    """
    results = []
    for m, k, n in MKNs:
        # Materialize: the result is both printed and appended below, and a
        # one-shot iterator would be exhausted by the first consumer.
        timers = list(
            bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
        )
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
369
370
371
372
373
374
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements as one table and pickle them to disk.

    Args:
        data: Measurements to report.
        MKNs: Shapes that produced ``data``. Currently unused; kept for
            interface compatibility.
        base_description: Used in the printed header and as the output
            file-name prefix.
        timestamp: File-name suffix; defaults to the current Unix time in
            whole seconds.

    Writes ``f"{base_description}-{timestamp}.pkl"`` (a path relative to the
    working directory) containing the measurements as a list.
    """
    # Materialize once: a one-shot iterator would otherwise be consumed by
    # the table printer and leave nothing to pickle.
    data = list(data)

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (m == k == n) over a size sweep.

    Sizes go from ``args.dim_start`` up to and including ``args.dim_end`` in
    steps of ``args.dim_increment``. Results are printed and pickled via
    ``make_output``.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = [(dim, dim, dim) for dim in dim_sizes]
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Benchmark a sweep of GEMM shapes, optionally holding M, K or N fixed.

    Sizes go from ``args.dim_start`` up to and including ``args.dim_end`` in
    steps of ``args.dim_increment``. The end point is inclusive, as in
    ``run_square_bench``, which takes the same flags.

    M, K and N each follow the sweep unless ``args.m_constant``,
    ``args.k_constant`` or ``args.n_constant`` is given, in which case that
    dimension stays fixed. Results are printed and pickled via
    ``make_output``.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark model weight shapes for every (model, TP size) pair.

    For each pair, the model's (K, N) shapes from ``WEIGHT_SHAPES`` have
    their TP-split dimension divided by the TP size. They are then crossed
    with ``args.batch_sizes`` as M.

    Results are printed per pair. All measurements are then pickled together
    to ``f"model_bench-{args.dtype}-{timestamp}.pkl"``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[list[int]]:
        # Each entry unpacks to ([K, N], tp_split_dim). Deep-copy so the
        # in-place division below never mutates the shared WEIGHT_SHAPES.
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        KNs = model_shapes(model, tp_size)
        # Batch size is the outer loop, weight shape the inner one.
        MKNs = [(m, k, n) for m in args.batch_sizes for k, n in KNs]
        model_bench_data.append(run(args.dtype, MKNs))

    # Print all results
    for (model, tp_size), data in zip(models_tps, model_bench_data):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = list(itertools.chain.from_iterable(model_bench_data))
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


448
if __name__ == "__main__":

    def to_torch_dtype(dt):
        # argparse ``type=`` converter; a ValueError becomes a usage error.
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
    
    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
    
    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
    
    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['int8', 'fp8']",
    )

    # A subcommand is mandatory. Without ``required=True``, omitting it
    # leaves ``args.func`` unset and the script crashes with AttributeError
    # instead of printing a usage error.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)