# benchmark_machete.py 14.1 KB
# Newer Older
1
2
3
4
5
6
import argparse
import copy
import itertools
import math
import pickle as pkl
import time
7
8
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
9

10
import pandas as pd
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]


def machete_pack_weights(w_q: torch.Tensor, wtype: ScalarType) -> torch.Tensor:
    """Pack quantized weights into the layout consumed by Machete kernels.

    Args:
        w_q: Unpacked quantized weight values, one value per element
            (presumably the ``w_q`` returned by ``quantize_weights``, shape
            (K, N) -- TODO confirm).
        wtype: Quantized weight scalar type; ``wtype.size_bits`` is the
            number of bits each value occupies when packed.

    Returns:
        The output of ``ops.machete_prepack_B`` for the row-packed,
        column-major weights.
    """
    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    # Transpose, materialize, transpose back: same logical shape, but the
    # underlying storage is now column-major.
    w_q = w_q.t().contiguous().t()
    return ops.machete_prepack_B(w_q, wtype)


def make_bench_tensors(
    atype: torch.dtype,
    wtype: ScalarType,
    group_size: int,
    m: int,
    n: int,
    k: int,
    *,
    l2_cache_bytes: int = 50 * 1024**2,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                    torch.Tensor]]]:
    """Create the activation and a set of quantized weights for one GEMM.

    Enough distinct (k, n) weights are generated that their combined
    quantized size exceeds twice ``l2_cache_bytes``, so cycling through them
    keeps weights from staying resident in L2 between runs.

    Args:
        atype: dtype of the activation and of the full-precision weights.
        wtype: Quantized weight type; must be an integer type.
        group_size: Quantization group size passed to ``quantize_weights``.
        m, n, k: GEMM dimensions; the activation is (m, k) and each weight
            is (k, n). Note the (m, n, k) parameter order.
        l2_cache_bytes: L2 cache size to defeat. The default of 50 MiB
            matches an H100.

    Returns:
        ``(a, quantized_weights)``, where ``a`` is a CUDA tensor of shape
        (m, k) and ``quantized_weights`` holds one ``quantize_weights``
        result per generated weight.

    Raises:
        NotImplementedError: if ``wtype`` is not an integer type.
    """
    # Raise explicitly instead of asserting so the check survives `python -O`.
    if not wtype.is_integer():
        raise NotImplementedError("TODO: support floating point weights")

    # Target a total packed weight size (in bits) above 2x the L2 size.
    num_weights = math.ceil(2 * l2_cache_bytes * 8 /
                            (k * n * wtype.size_bits))

    a = torch.randn((m, k), device="cuda", dtype=atype) * 5
    weights = [
        torch.randn((k, n), device="cuda", dtype=atype)
        for _ in range(num_weights)
    ]
    quantized_weights = [
        quantize_weights(w, wtype, group_size) for w in weights
    ]

    return a, quantized_weights


# impl


# bench
def bench_fn(label: str,
             sub_label: str,
             description: str,
             fn: Callable,
             min_run_time: float = 1) -> TMeasurement:
    """Time the zero-argument callable ``fn`` with torch.utils.benchmark.

    Args:
        label, sub_label, description: Row/column labels attached to the
            resulting measurement (used by ``TBenchmark.Compare``).
        fn: Zero-argument callable to time.
        min_run_time: Minimum total measurement time in seconds forwarded to
            ``Timer.blocked_autorange``. Defaults to 1 second, as before.

    Returns:
        The ``Measurement`` produced by ``blocked_autorange``.
    """
    timer = TBenchmark.Timer(
        stmt="fn()",
        globals={"fn": fn},
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)


def loop_over_weights(
        a: torch.tensor,
        weights: List[Tuple[torch.tensor, torch.tensor, torch.tensor,
                            torch.tensor]],
        fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
                     None]):
    """Call ``fn(a, w_ref, w_q, w_s)`` for every 4-tuple in ``weights``.

    Every entry is unpacked as exactly four items; the fourth item is not
    forwarded to ``fn``. Any return value of ``fn`` is discarded.
    """
    for entry in weights:
        ref, packed, scales, _ = entry
        fn(a, ref, packed, scales)


89
90
91
92
# Output path for the schedule-sweep CSV; assigned from --sweep-csv-out.
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
# One row per timed Machete schedule, appended by `bench` during a sweep;
# stays None until the first row is recorded.
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None


93
94
95
96
97
98
99
100
101
102
def bench(atype: torch.dtype,
          wtype: ScalarType,
          group_size: int,
          m: int,
          k: int,
          n: int,
          label: str,
          sub_label: str,
          benchmark_marlinv1: bool = True,
          sweep_schedules: bool = True) -> Iterable[TMeasurement]:
    """Benchmark one (M, K, N) GEMM with torch, Marlin v1 and Machete.

    Each timing loops over the whole list of weights built by
    ``make_bench_tensors``; the number of weights is appended to
    ``sub_label`` as ``", L=<count>"``.

    Args:
        atype: dtype of the activation and of the reference weights.
        wtype: Quantized weight type.
        group_size: Quantization group size of the weight scales.
        m, k, n: GEMM problem size (activation (m, k), weights (k, n)).
        label, sub_label: Labels attached to every measurement.
        benchmark_marlinv1: Also time ``ops.gptq_marlin_gemm``.
        sweep_schedules: Also time each non-pruned Machete schedule, append
            one row per schedule to ``_SWEEP_SCHEDULES_RESULTS``, and add the
            fastest as a ``"machete_best"`` measurement.

    Returns:
        The measurements in the order they were taken. ``"machete_best"`` is
        omitted when every schedule was pruned.
    """
    global _SWEEP_SCHEDULES_RESULTS

    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
    sub_label += f", L={len(weights)}"

    weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
                       for w_ref, w_q, w_s, w_zp in weights]

    timers = []
    # pytorch impl on the unquantized reference weights
    timers.append(
        bench_fn(
            label, sub_label, "torch.matmul", lambda: loop_over_weights(
                a,
                weights,
                lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
            )))

    if benchmark_marlinv1:
        # All weights are created with the same (k, n) shape, so the first
        # reference weight provides the dimensions used below.
        w_ref = weights[0][0]

        # Empty index tensors passed to the Marlin ops (presumably meaning
        # "no act-order / no zero points" -- TODO confirm).
        w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
        sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
        g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)

        def marlinv1_pack_weights(w_q: torch.Tensor) -> torch.Tensor:
            w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
            return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
                                          wtype.size_bits)

        def marlinv1_permute_scales(w_s: torch.Tensor) -> torch.Tensor:
            return marlin_permute_scales(w_s, *w_ref.shape, group_size)

        weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
                             marlinv1_permute_scales(w_s), w_zp)
                            for w_ref, w_q, w_s, w_zp in weights]

        workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

        def run_marlinv1(a, w_ref, w_q, w_s):
            ops.gptq_marlin_gemm(a,
                                 w_q,
                                 w_s,
                                 w_zp_empty,
                                 g_idx,
                                 sort_indices,
                                 workspace.scratch,
                                 wtype,
                                 size_m=a.shape[0],
                                 size_n=w_ref.shape[1],
                                 size_k=w_ref.shape[0],
                                 is_k_full=True)

        # marlinv1
        timers.append(
            bench_fn(
                label, sub_label, "marlin_orig",
                lambda: loop_over_weights(a, weights_marlinv1, run_marlinv1)))

    # machete with its default (heuristic) schedule selection
    timers.append(
        bench_fn(
            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))

    if sweep_schedules:
        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):
            # Assumes schedule names start with "<a>x<M>_..."; the second
            # "x"-separated number is treated as the schedule's M tile.
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            # Bind `schedule` as a default so the closure sees this
            # iteration's value.
            def run_schedule(a, _, w_q, w_s, schedule=schedule):
                ops.machete_gemm(a,
                                 w_q,
                                 wtype,
                                 w_s,
                                 b_group_size=group_size,
                                 schedule=schedule)

            res = bench_fn(
                label, sub_label, "machete_best",
                lambda: loop_over_weights(a, weights_machete, run_schedule))

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
                    columns=list(results_row))
            _SWEEP_SCHEDULES_RESULTS.loc[len(
                _SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f"  {res.median:5.5} ", schedule)
            if best is None or res.median < best.median:
                best = res
                best_schedule = schedule

        if best is None:
            # Appending None would crash TBenchmark.Compare downstream.
            print("No machete schedule left after pruning; "
                  "skipping machete_best")
        else:
            print("Best schedule:", best_schedule)
            timers.append(best)

    return timers


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print ``timers`` as a torch.utils.benchmark comparison table."""
    TBenchmark.Compare(timers).print()


def run(dtype: torch.dtype, sweep_schedules: bool,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    """Benchmark each (M, K, N) in ``MKNs`` with uint4b8 weights, group 128.

    The table for each shape is printed as soon as it finishes; all
    measurements are returned concatenated in shape order.
    """
    collected = []
    for shape in MKNs:
        m, k, n = shape
        shape_timers = bench(
            dtype,
            scalar_types.uint4b8,
            128,
            m,
            k,
            n,
            f"{dtype}-gemm",
            f"MKN=({m}x{k}x{n})",
            sweep_schedules=sweep_schedules,
        )
        print_timers(shape_timers)
        collected.extend(shape_timers)
    return collected


# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[Tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements in ``data`` and pickle them to disk.

    The pickle is written to ``<base_description>-<timestamp>.pkl`` in the
    working directory. ``timestamp`` defaults to the current Unix time in
    whole seconds. ``MKNs`` is accepted but unused.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    if timestamp is None:
        timestamp = int(time.time())
    out_path = f"{base_description}-{timestamp}.pkl"
    with open(out_path, "wb") as out_file:
        pkl.dump(data, out_file)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (M == K == N) over an inclusive size range.

    Uses ``args.dim_start``, ``args.dim_end`` (inclusive) and
    ``args.dim_increment`` for the sizes, plus ``args.dtype`` and
    ``args.sweep_schedules``, then prints and pickles the results.
    """
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))

    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Benchmark the Cartesian product of independent M, K and N ranges.

    ``args.dim_start``, ``args.dim_end`` (inclusive) and
    ``args.dim_increment`` are each comma separated ``M,K,N`` triples, for
    example ``"16,4096,4096"``. Results are printed and pickled.

    Raises:
        ValueError: if any of the three arguments is not exactly three
            comma separated integers.
    """

    def parse_mkn(value: str, flag: str) -> Tuple[int, int, int]:
        # Parse one "M,K,N" argument, naming the offending flag on error.
        parts = value.split(",")
        if len(parts) != 3:
            raise ValueError(
                f"{flag} expects three comma separated values (M,K,N), "
                f"got {value!r}")
        m_val, k_val, n_val = (int(p) for p in parts)
        return m_val, k_val, n_val

    m_start, k_start, n_start = parse_mkn(args.dim_start, "--dim-start")
    m_end, k_end, n_end = parse_mkn(args.dim_end, "--dim-end")
    m_increment, k_increment, n_increment = parse_mkn(
        args.dim_increment, "--dim-increment")

    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark the weight shapes of each (model, TP size) combination.

    For every pair in ``args.models`` x ``args.tp_sizes``, the shapes from
    ``WEIGHT_SHAPES`` are divided by the TP size along their split dimension
    and crossed with ``args.batch_sizes`` as M. Per-pair tables are printed
    at the end, and all measurements are pickled to
    ``model_bench-<dtype>-<timestamp>.pkl``.
    """
    print("Benchmarking models:")
    for idx, model_name in enumerate(args.models):
        print(f"[{idx}]  {model_name}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        # Work on a deep copy: the TP split below edits each entry in place
        # and must not touch WEIGHT_SHAPES itself.
        shapes = []
        for kn, split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            kn[split_dim] //= tp_size
            shapes.append(kn)
        return shapes

    models_tps = list(itertools.product(args.models, args.tp_sizes))
    model_bench_data = []
    for model, tp_size in models_tps:
        kns = model_shapes(model, tp_size)
        mkns = [(m, k, n) for m in args.batch_sizes for k, n in kns]
        model_bench_data.append(run(args.dtype, args.sweep_schedules, mkns))

    # Print all results
    for (model, tp_size), data in zip(models_tps, model_bench_data):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = list(itertools.chain.from_iterable(model_bench_data))
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as out_file:
        pkl.dump(all_data, out_file)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        # Map the --dtype string to a torch dtype. argparse reports the
        # ValueError below as an invalid --dtype value.
        if dt == "bfloat16":
            return torch.bfloat16
        if dt == "float16":
            return torch.float16
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M (range_bench takes comma separated M,K,N values):
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128,16384,16384 --dim-end 512,16384,16384 --dim-increment 64,1,1

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['bfloat16', 'float16']",
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument("--sweep-csv-out",
                        help="CSV to store sweep results",
                        default="sch_sweep_results.csv")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list")
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    # Module scope: this rebinds the global read when writing the CSV below.
    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    # Only populated when a schedule sweep actually timed something.
    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)