# w8a8_benchmarks.py: benchmark CUTLASS w8a8 (int8 / fp8) scaled GEMMs against PyTorch.
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser

# Models benchmarked by model_bench when --models is not given: every key of
# WEIGHT_SHAPES.
DEFAULT_MODELS = [*WEIGHT_SHAPES]
# Batch sizes used as the GEMM M dimension in model_bench.
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
# Tensor-parallel degrees; model_bench divides each weight's split dim by it.
DEFAULT_TP_SIZES = [1]

# helpers


laibao's avatar
laibao committed
23
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
zhuwenwen's avatar
zhuwenwen committed
24
25
26
27
28
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate to [-128, 127], round to nearest, and cast to int8."""
    saturated = tensor.clamp(min=-128, max=127)
    return saturated.round().to(dtype=torch.int8)


def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
laibao's avatar
laibao committed
34
                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
zhuwenwen's avatar
zhuwenwen committed
35
36
37
38
39
40
41
42
43
44
45
46
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    if dtype == torch.int8:
        return to_int8(a), to_int8(b)
    if dtype == torch.float8_e4m3fn:
        return to_fp8(a), to_fp8(b)

    raise ValueError("unsupported dtype")


# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
             **kwargs) -> TMeasurement:
    """Time ``fn(*args, **kwargs)`` with ``torch.utils.benchmark``.

    ``label``, ``sub_label`` and ``description`` are attached to the
    returned measurement as metadata. Timing uses ``blocked_autorange``
    with a minimum total run time of ``min_run_time`` seconds.

    Returns:
        The ``Measurement`` produced by ``blocked_autorange``.
    """
    min_run_time = 1

    # Namespace the timed statement is evaluated in. Named so that it does
    # not shadow the ``globals`` builtin.
    timer_globals = {
        "args": args,
        "kwargs": kwargs,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
    """Benchmark int8 x int8 -> bf16 GEMM kernels for one (m, k, n) shape.

    Baselines are ``torch.mm`` on the int8 operands upcast to bf16 / fp16
    (no scales). They are compared with ``ops.cutlass_scaled_mm`` (with and
    without bias) and ``ops.cutlass_scaled_mm_azp`` (per-tensor and
    per-token azp, with and without bias). Scales are 1.0 and bias/azp
    tensors are all zeros; kernel outputs are discarded, so only run time
    is measured.

    Returns:
        The measurements, in the order the kernels were run.
    """
    assert dtype == torch.int8
    a, b = make_rand_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
    azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
    azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)

    timers = []

    def record(description: str, fn: Callable, *args, **kwargs) -> None:
        # Time one kernel invocation and keep its measurement in run order.
        timers.append(
            bench_fn(label, sub_label, description, fn, *args, **kwargs))

    # PyTorch baselines: plain matmul on upcast operands, no scaling.
    record("pytorch_bf16_bf16_bf16_matmul-no-scales", torch.mm,
           a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16))
    record("pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
           a.to(dtype=torch.float16), b.to(dtype=torch.float16))

    # CUTLASS scaled mm, without and with bias.
    record("cutlass_i8_i8_bf16_scaled_mm", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.bfloat16)
    record("cutlass_i8_i8_bf16_scaled_mm_bias", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.bfloat16, bias)

    # CUTLASS scaled mm with azp. Positional tail is (azp_adj[, azp[, bias]]);
    # passing azp=None is how the per-tensor + bias variant is requested.
    record("cutlass_i8_i8_bf16_scaled_mm_azp", ops.cutlass_scaled_mm_azp, a,
           b, scale_a, scale_b, torch.bfloat16, azp_adj)
    record("cutlass_i8_i8_bf16_scaled_mm_azp_bias", ops.cutlass_scaled_mm_azp,
           a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias)
    record("cutlass_i8_i8_bf16_scaled_mm_azp_pt", ops.cutlass_scaled_mm_azp,
           a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp)
    record("cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
           ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, torch.bfloat16,
           azp_adj, azp, bias)

    return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
              sub_label: str) -> Iterable[TMeasurement]:
    """Benchmark fp8 x fp8 GEMM kernels for one (m, k, n) shape.

    The kernels timed are:
      - ``torch.mm`` on the operands upcast to bf16 (no scales);
      - ``torch._scaled_mm`` with bf16 and fp16 output, each without and
        with ``use_fast_accum=True``;
      - ``ops.cutlass_scaled_mm`` with bf16 and fp16 output, without and
        with bias.

    Scales are 1.0 and the bias is all zeros; outputs are discarded, so only
    run time is measured.

    Returns:
        The measurements, in the order the kernels were run.
    """
    assert dtype == torch.float8_e4m3fn
    a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

    timers = []

    def record(description: str, fn: Callable, *args, **kwargs) -> None:
        # Time one kernel invocation and keep its measurement in run order.
        timers.append(
            bench_fn(label, sub_label, description, fn, *args, **kwargs))

    # PyTorch baseline: bf16 matmul on upcast operands, no scaling.
    record("pytorch_bf16_bf16_bf16_matmul-no-scales", torch.mm,
           a.to(dtype=torch.bfloat16, device="cuda"),
           b.to(dtype=torch.bfloat16, device="cuda"))

    # torch._scaled_mm per output dtype, without then with fp8 fast accum.
    # use_fast_accum is only passed when enabled (its default is left as-is).
    scaled_mm_variants = (
        ("pytorch_fp8_fp8_bf16_scaled_mm", torch.bfloat16, False),
        ("pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", torch.bfloat16, True),
        ("pytorch_fp8_fp8_fp16_scaled_mm", torch.float16, False),
        ("pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", torch.float16, True),
    )
    for description, out_dtype, fast_accum in scaled_mm_variants:
        extra_kwargs = {"use_fast_accum": True} if fast_accum else {}
        record(description,
               torch._scaled_mm,
               a,
               b,
               scale_a=scale_a,
               scale_b=scale_b,
               out_dtype=out_dtype,
               **extra_kwargs)

    # CUTLASS scaled mm: bf16 / fp16 output, then the same with bias.
    record("cutlass_fp8_fp8_bf16_scaled_mm", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.bfloat16)
    record("cutlass_fp8_fp8_fp16_scaled_mm", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.float16)
    record("cutlass_fp8_fp8_bf16_scaled_mm_bias", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.bfloat16, bias)
    # The bias is converted to match the fp16 output dtype.
    record("cutlass_fp8_fp8_fp16_scaled_mm_bias", ops.cutlass_scaled_mm, a, b,
           scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16))

    return timers


def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print ``timers`` as one ``TBenchmark.Compare`` table."""
    TBenchmark.Compare(timers).print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    """Benchmark every (m, k, n) shape in ``MKNs`` for ``dtype``.

    Each shape's table is printed as soon as it finishes.

    Returns:
        A flat list of all measurements, in shape order.
    """
    collected = []
    for shape in MKNs:
        m, k, n = shape
        shape_timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                             f"MKN=({m}x{k}x{n})")
        print_timers(shape_timers)
        collected.extend(shape_timers)
    return collected


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):
    """Print all measurements in one table and pickle them to disk.

    Args:
        data: measurements to report and save.
        MKNs: the benchmarked shapes (not used by this function).
        base_description: banner text and output-file prefix.
        timestamp: output-file suffix; defaults to the current Unix time.

    The file written is ``<base_description>-<timestamp>.pkl``.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    if timestamp is None:
        timestamp = int(time.time())
    out_path = f"{base_description}-{timestamp}.pkl"
    with open(out_path, "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (M = K = N) for dims in [dim_start, dim_end]."""
    # dim_end is inclusive, hence the + 1.
    sizes = range(args.dim_start, args.dim_end + 1, args.dim_increment)
    MKNs = [(d, d, d) for d in sizes]
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Sweep GEMM dims over [dim_start, dim_end], holding some dims fixed.

    Each of M, K, N follows the swept sizes unless the matching
    ``--m-constant`` / ``--k-constant`` / ``--n-constant`` is given, in which
    case that dimension stays at the constant for every point.
    """
    # Inclusive upper bound, so --dim-end means the same thing here as it
    # does for square_bench (which iterates up to dim_end + 1).
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    count = len(dim_sizes)

    def swept_or_constant(constant):
        # A fixed dimension repeats its constant once per sweep point.
        return dim_sizes if constant is None else [constant] * count

    MKNs = list(
        zip(swept_or_constant(args.m_constant),
            swept_or_constant(args.k_constant),
            swept_or_constant(args.n_constant)))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark each model's weight GEMMs at every requested TP size.

    For every (model, tp_size) pair the shapes in WEIGHT_SHAPES are divided
    along their TP split dimension, crossed with ``args.batch_sizes`` as M,
    and benchmarked. Per-pair tables are printed at the end, and all
    measurements are pickled to ``model_bench-<dtype>-<timestamp>.pkl``.
    """
    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}]  {name}")

    def model_shapes(model_name: str, tp_size: int) -> List[List[int]]:
        # Entries unpack as ([k, n], split_dim). Deep-copy so the in-place
        # division below does not mutate the shared WEIGHT_SHAPES table.
        shapes = []
        for kn, split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            kn[split_dim] //= tp_size
            shapes.append(kn)
        return shapes

    models_tps = list(itertools.product(args.models, args.tp_sizes))
    model_bench_data = []
    for model, tp_size in models_tps:
        KNs = model_shapes(model, tp_size)
        MKNs = [(m, k, n) for m in args.batch_sizes for k, n in KNs]
        model_bench_data.append(run(args.dtype, MKNs))

    # Print all results
    for (model, tp_size), data in zip(models_tps, model_bench_data):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    # pickle all data, flattened across (model, tp_size) pairs
    all_data = list(itertools.chain.from_iterable(model_bench_data))
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        """Map the --dtype CLI string ("int8" or "fp8") to a torch dtype."""
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    # NOTE(review): FlexibleArgumentParser is used here like an
    # argparse.ArgumentParser (add_argument / add_subparsers / parse_args).
    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
    
    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
    
    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
    
    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    # A sub-command is mandatory: each one installs the `func` default that is
    # dispatched below, so without one `args.func` would not exist and the
    # script would crash with an AttributeError instead of a usage error.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)