benchmark_machete.py 21.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
import argparse
import copy
import itertools
import math
8
import os
9
10
import pickle as pkl
import time
11
from collections.abc import Callable, Iterable
12
from dataclasses import dataclass
13
from itertools import product
14

15
import pandas as pd
16
17
18
19
20
21
22
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
23
24
25
26
27
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    marlin_permute_scales,
    marlin_zero_points,
)
28
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
29
30
    MarlinWorkspace,
)
31
from vllm.model_executor.layers.quantization.utils.quant_utils import (
32
33
34
    pack_rows,
    quantize_weights,
)
35
36
37
38
39
40
41
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]


def env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is unset. An empty value or one of
    "0", "false", "no", "off" (case-insensitive, surrounding whitespace
    ignored) is false; any other value is true.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ("", "0", "false", "no", "off")


# When enabled, bench_fns uses a shorter timing window and runs one extra
# NVTX-annotated call per benchmark. Parsed as a real boolean so that e.g.
# NVTX_PROFILE=0 disables profiling instead of enabling it.
NVTX_PROFILE = env_flag("NVTX_PROFILE")

if NVTX_PROFILE:
    import nvtx


def terse_type_name(dt):
    return {
        torch.bfloat16: "bf16",
        torch.float16: "fp16",
        torch.int8: "int8",
        torch.float8_e4m3fn: "fp8",
        torch.float: "float",
        torch.int: "int",
    }[dt]


@dataclass
class BenchmarkTensors:
    """All tensors needed to benchmark one quantized GEMM instance.

    Field order defines the positional constructor; keep it stable.
    """

    # Dequantized reference weights (create_bench_tensors stores them as
    # float32); the torch.matmul / cutlass_scaled_mm baselines use these.
    w_ref: torch.Tensor
    # Activations; create_bench_tensors shares one tensor across all weights.
    a: torch.Tensor

    # Quantized weights after pack_rows (see quantize_and_pack).
    w_q: torch.Tensor
    group_size: int | None
    wtype: ScalarType
    # Group scales and optional group zero points from quantize_weights.
    w_g_s: torch.Tensor
    w_g_zp: torch.Tensor | None
    # Optional per-output-channel (length n) and per-token (length m) scales.
    w_ch_s: torch.Tensor | None
    w_tok_s: torch.Tensor | None


@dataclass
class TypeConfig:
    """Data types describing one GEMM benchmark configuration.

    Field order defines the positional constructor; keep it stable. ``None``
    for an optional field means that tensor/feature is not used.
    """

    act_type: torch.dtype
    weight_type: ScalarType
    # Requested output dtype (passed to machete as ``out_type``).
    output_type: torch.dtype | None
    group_scale_type: torch.dtype | None
    # Non-None enables group zero points (see create_bench_tensors / run).
    group_zero_type: torch.dtype | None
    channel_scale_type: torch.dtype | None
    token_scale_type: torch.dtype | None


def rand_data(shape, dtype=torch.float16, scale=1):
    """Create a random CUDA tensor of ``shape`` and ``dtype``.

    Floating dtypes get uniform samples in ``[-0.3, scale - 0.3)``; all other
    dtypes get integers drawn from ``[-15, 15)``.
    """
    if not dtype.is_floating_point:
        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
    samples = torch.rand(shape, device="cuda")
    return (scale * samples - 0.3).to(dtype)


91
92
93
94
def quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """Quantize ``w`` to the integer type ``wtype`` and pack the result.

    ``atype`` and ``stype`` are accepted for call-site compatibility but are
    not used by this function.

    Returns:
        ``(w_ref, w_q_packed, w_s, w_zp)``, where ``w_ref``, ``w_s`` and
        ``w_zp`` come straight from ``quantize_weights``. ``w_q_packed`` is
        the quantized tensor passed through ``pack_rows``.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    quantized = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True,
    )
    w_ref, w_q, w_s, w_zp = quantized

    packed = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    return w_ref, packed, w_s, w_zp
112
113


114
def create_bench_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]:
    """Build the list of BenchmarkTensors used for one GEMM shape.

    ``shape`` is ``(m, n, k)``. One activation tensor of shape ``(m, k)`` is
    shared by every entry. Each entry gets its own freshly quantized
    ``(k, n)`` weight.
    """
    m, n, k = shape

    # Use enough distinct weight matrices that their combined size exceeds
    # twice a 50 MiB L2 cache (the figure assumed for H100). That way no
    # weight stays cache-resident between successive benchmark runs.
    l2_bytes = 50 * 1024**2
    bits_per_weight_matrix = k * n * types.weight_type.size_bits
    num_weights = math.ceil(2 * l2_bytes * 8 / bits_per_weight_matrix)

    a = rand_data((m, k), types.act_type, scale=5)
    use_zero_points = types.group_zero_type is not None

    def make_one() -> BenchmarkTensors:
        w = rand_data((k, n), types.act_type, scale=5)

        if types.group_scale_type is not None:
            w = w.to(types.group_scale_type)
        # 1-byte dtypes are widened to fp16 before quantization.
        if w.dtype.itemsize == 1:
            w = w.to(torch.float16)

        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
            a.dtype,
            w,
            types.weight_type,
            types.group_scale_type,
            group_size,
            use_zero_points,
        )

        # For integer activations, keep the reference weights representable
        # in the activation dtype.
        if not a.dtype.is_floating_point:
            limits = torch.iinfo(a.dtype)
            w_ref = w_ref.round().clamp(limits.min, limits.max)
        w_ref = w_ref.to(torch.float32)

        if types.channel_scale_type is None:
            w_ch_s = None
        else:
            w_ch_s = rand_data((n,), types.channel_scale_type)
        if types.token_scale_type is None:
            w_tok_s = None
        else:
            w_tok_s = rand_data((m,), types.token_scale_type)

        return BenchmarkTensors(
            w_ref=w_ref,
            a=a,
            w_q=w_q_packed,
            wtype=types.weight_type,
            w_g_s=w_s,
            w_g_zp=w_zp,
            group_size=group_size,
            w_ch_s=w_ch_s,
            w_tok_s=w_tok_s,
        )

    return [make_one() for _ in range(num_weights)]


def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Return a closure running ``torch.matmul`` on the reference weights.

    Reference weights are cast to the activation dtype. If that dtype is not
    fp16/bf16, both operands are then cast to fp16.
    """
    half_types = (torch.float16, torch.bfloat16)
    act = bt.a
    weight = bt.w_ref.to(bt.a.dtype)  # use float reference tensor

    if act.dtype in half_types:
        return lambda: torch.matmul(act, weight)

    act16 = act.to(torch.float16)
    weight16 = weight.to(torch.float16)
    return lambda: torch.matmul(act16, weight16)


def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Return a closure running ``ops.cutlass_scaled_mm`` with fp16 output.

    The token/channel scales are used only when both are present. Otherwise
    both scales are scalar 1.0 float32 tensors.
    """
    if bt.w_ch_s is None or bt.w_tok_s is None:
        scalar_opts = {"dtype": torch.float32, "device": bt.a.device}
        scale_a = torch.tensor(1.0, **scalar_opts)
        scale_b = torch.tensor(1.0, **scalar_opts)
    else:
        scale_a = bt.w_tok_s.to(torch.float32)
        scale_b = bt.w_ch_s.to(torch.float32)

    # .t().contiguous().t() keeps the logical shape but lays the storage out
    # column-major.
    w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()

    def run_scaled_mm():
        return ops.cutlass_scaled_mm(
            bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
        )

    return run_scaled_mm
200
201
202
203
204


def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    """Prepare Marlin-format operands and return a ``gptq_marlin_gemm`` closure.

    Only floating-point activations are supported. Integer (QQQ) activations
    raise NotImplementedError after the int8/uint4b8 asserts.
    """
    device = bt.a.device
    size_k = bt.w_ref.shape[0]
    size_n = bt.w_ref.shape[1]

    workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    def empty_int() -> torch.Tensor:
        return torch.empty(0, dtype=torch.int, device=device)

    w_zp = (
        empty_int()
        if bt.w_g_zp is None
        else marlin_zero_points(bt.w_g_zp, size_k, size_n, bt.wtype.size_bits)
    )
    w_s = (
        torch.tensor([], device="cuda", dtype=torch.half)
        if bt.group_size is None
        else marlin_permute_scales(bt.w_g_s, size_k, size_n, bt.group_size)
    )

    # No act-order: empty permutation / group-index tensors.
    sort_indices = empty_int()
    g_idx = empty_int()
    w_q = ops.gptq_marlin_repack(
        bt.w_q, sort_indices, size_k, size_n, bt.wtype.size_bits
    )

    if not bt.a.dtype.is_floating_point:
        assert bt.a.dtype == torch.int8
        assert bt.wtype == scalar_types.uint4b8
        raise NotImplementedError("QQQ is not supported anymore")

    assert bt.w_ch_s is None
    assert bt.w_tok_s is None
    assert bt.group_size is not None

    def run_marlin():
        return ops.gptq_marlin_gemm(
            a=bt.a,
            c=None,
            b_q_weight=w_q,
            b_bias=None,
            b_scales=w_s,
            global_scale=None,
            b_zeros=w_zp,
            g_idx=g_idx,
            perm=sort_indices,
            workspace=workspace.scratch,
            b_q_type=bt.wtype,
            size_m=bt.a.shape[0],
            size_n=size_n,
            size_k=size_k,
            is_k_full=True,
            is_zp_float=False,
        )

    return run_marlin


260
261
262
def machete_create_bench_fn(
    bt: BenchmarkTensors,
    out_type: torch.dtype | None = None,
    schedule: str | None = None,
) -> Callable:
    """Prepack weights for machete and return a ``machete_mm`` closure.

    Args:
        bt: tensors for one benchmark instance.
        out_type: output dtype forwarded to ``ops.machete_mm``. ``None``
            leaves the choice to the op. The previous default was the
            ``torch.dtype`` class itself, which is not a valid dtype.
        schedule: optional machete schedule name; ``None`` lets the op choose.
    """
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.machete_prepack_B(
        w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
    )

    # Convert integer zero points into pre-scaled, negated float zeros
    # (-scale * zp) in the group-scale dtype before passing them to machete.
    w_g_zp = bt.w_g_zp
    if w_g_zp is not None:
        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))

    return lambda: ops.machete_mm(
        a=bt.a,
        b_q=w_q,
        b_type=bt.wtype,
        b_group_scales=bt.w_g_s,
        b_group_zeros=w_g_zp,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        out_type=out_type,
        schedule=schedule,
    )
284
285


286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def cutlass_w4a8_create_bench_fn(
    bt: BenchmarkTensors,
    out_type: torch.dtype | None = None,
    schedule: str | None = None,
) -> Callable:
    """Reorder int4 weights / fp8 scales and return a ``cutlass_w4a8_mm`` closure.

    Args:
        bt: tensors for one benchmark instance.
        out_type: accepted for signature parity with
            ``machete_create_bench_fn`` but currently NOT forwarded to the
            kernel. Its default used to be the ``torch.dtype`` class itself,
            which is not a valid dtype.
        schedule: optional schedule name passed as ``maybe_schedule``.
    """
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
    # expects fp8 scales
    w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))

    return lambda: ops.cutlass_w4a8_mm(
        a=bt.a,
        b_q=w_q,
        b_group_scales=w_s,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        maybe_schedule=schedule,
    )


305
306
307
308
# impl

# bench

309

310
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
    """Time calling every function in ``fns`` back to back as one statement.

    Uses ``blocked_autorange`` with a 1 s minimum run time, or 0.1 s when
    NVTX_PROFILE is set. In that case one extra call of ``fns[0]`` is made
    inside NVTX ranges so it can be located in a profiler trace.

    Returns:
        The torch.utils.benchmark Measurement.
    """
    timer = TBenchmark.Timer(
        stmt="""
        for fn in fns:
            fn()
        """,
        globals={"fns": fns},
        label=label,
        sub_label=sub_label,
        description=description,
    )
    measurement = timer.blocked_autorange(
        min_run_time=0.1 if NVTX_PROFILE else 1
    )

    if not NVTX_PROFILE:
        return measurement

    with nvtx.annotate("mm-bench"):
        with nvtx.annotate(f"{label}|{sub_label}|{description}"):
            fns[0]()
    return measurement
331
332


333
334
_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None
335
336


337
338
339
340
341
342
343
344
345
346
def _sweep_machete_schedules(
    types: TypeConfig,
    group_size: int | None,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    benchmark_tensors: list[BenchmarkTensors],
) -> TMeasurement | None:
    """Time each non-pruned machete schedule and return the fastest Measurement.

    Every timed schedule is also appended as a row to the module-level
    ``_SWEEP_SCHEDULES_RESULTS`` DataFrame, which is created on first use.
    Returns ``None`` when every supported schedule is pruned for this ``m``.

    Raises:
        ValueError: if machete reports no supported schedules at all.
    """
    global _SWEEP_SCHEDULES_RESULTS

    print("Finding best schedule for machete")
    best = None
    best_schedule = None
    schedules = ops.machete_supported_schedules(
        a_type=types.act_type,
        b_type=types.weight_type,
        group_scales_type=types.group_scale_type,
        group_zeros_type=types.group_zero_type,
        token_scales_type=types.token_scale_type,
        channel_scales_type=types.channel_scale_type,
        out_type=types.output_type,
    )

    if schedules is None or len(schedules) == 0:
        raise ValueError("No schedules found to sweep")

    for schedule in reversed(schedules):
        # The second "x"-separated number of the schedule name's first
        # "_"-separated field is used as the schedule's M tile size.
        schedule_M = int(schedule.split("_")[0].split("x")[1])

        # Prune known bad schedules
        if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
            continue

        res = bench_fns(
            label,
            sub_label,
            "machete_best",
            [
                machete_create_bench_fn(
                    bt, out_type=types.output_type, schedule=schedule
                )
                for bt in benchmark_tensors
            ],
        )

        results_row = {
            "M": m,
            "K": k,
            "N": n,
            "group_size": group_size,
            "schedule": schedule,
            "median": res.median,
        }
        if _SWEEP_SCHEDULES_RESULTS is None:
            _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
        _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

        print(f"  {res.median:5.5} ", schedule)
        if best is None or res.median < best.median:
            best = res
            best_schedule = schedule
    print("Best schedule:", best_schedule)
    return best


def bench(
    types: TypeConfig,
    group_size: int | None,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    sweep_schedules: bool = True,
) -> list[TMeasurement]:
    """Benchmark every applicable GEMM implementation for one (m, k, n) shape.

    Always times torch.matmul (fp16) and machete. It also times:
    - cutlass_scaled_mm for int8/fp8 activations;
    - marlin where ``marlin_create_bench_fn`` can run;
    - cutlass w4a8 for fp8 activations with group size 128.
    With ``sweep_schedules``, the fastest explicitly scheduled machete run is
    appended as well, when any schedule survives pruning.
    """
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label,
            sub_label,
            "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
        )
    )

    if types.act_type in (torch.int8, torch.float8_e4m3fn):
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
                [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    # marlin_create_bench_fn raises NotImplementedError for integer
    # activations and asserts there are no channel/token scales and that a
    # group size is set. Only run it where it can succeed; previously int8
    # activations crashed the whole benchmark here.
    marlin_supported = (
        types.act_type in (torch.float16, torch.bfloat16)
        and types.channel_scale_type is None
        and types.token_scale_type is None
        and group_size is not None
    )
    if marlin_supported:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"marlin ({name_type_string})",
                [marlin_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    # machete
    timers.append(
        bench_fns(
            label,
            sub_label,
            f"machete ({name_type_string})",
            [
                machete_create_bench_fn(bt, out_type=types.output_type)
                for bt in benchmark_tensors
            ],
        )
    )

    # cutlass w4a8
    if types.act_type == torch.float8_e4m3fn and group_size == 128:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass w4a8 ({name_type_string})",
                [
                    cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
                    for bt in benchmark_tensors
                ],
            )
        )

    if sweep_schedules:
        best = _sweep_machete_schedules(
            types, group_size, m, k, n, label, sub_label, benchmark_tensors
        )
        # All schedules can be pruned for some m. Never hand None to
        # TBenchmark.Compare, which expects Measurement objects.
        if best is not None:
            timers.append(best)

    return timers


# runner
481
def print_timers(timers: list[TMeasurement]):
    """Print a side-by-side comparison table of ``timers`` to stdout."""
    TBenchmark.Compare(timers).print()


486
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
    """Run ``bench`` for every (m, k, n) in ``MKNs`` and collect all timings.

    The TypeConfig is built from ``args``. Weights are ``uint4`` when a group
    zero type is given, otherwise ``uint4b8``. Each shape's comparison table
    is printed as soon as it finishes.
    """
    if args.group_zero_type is not None:
        weight_type = scalar_types.uint4
    else:
        weight_type = scalar_types.uint4b8

    types = TypeConfig(
        act_type=args.act_type,
        weight_type=weight_type,
        output_type=args.out_type,
        group_scale_type=args.group_scale_type,
        group_zero_type=args.group_zero_type,
        channel_scale_type=args.channel_scale_type,
        token_scale_type=args.token_scale_type,
    )

    results: list[TMeasurement] = []
    for m, k, n in MKNs:
        shape_timers = bench(
            types,
            args.group_size,
            m,
            k,
            n,
            label=f"{args.act_type}-gemm",
            sub_label=f"MKN=({m}x{k}x{n})",
            sweep_schedules=args.sweep_schedules,
        )
        print_timers(shape_timers)
        results.extend(shape_timers)

    return results


# output makers
def make_output(
    data: list[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements and pickle them to ``<base_description>-<timestamp>.pkl``.

    ``MKNs`` is accepted but unused. ``timestamp`` defaults to the current
    Unix time in whole seconds.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    if timestamp is None:
        timestamp = int(time.time())
    pickle_path = f"{base_description}-{timestamp}.pkl"
    with open(pickle_path, "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (M == K == N) over an inclusive range of sizes.

    Reads ``args.dim_start``, ``args.dim_end`` (inclusive) and
    ``args.dim_increment``; all other settings come from ``args`` via ``run``.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    # run() takes the parsed namespace plus the shape list. The parser
    # defines --act-type; there is no args.dtype.
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")


def run_range_bench(args):
    """Benchmark every (M, K, N) in the Cartesian product of three ranges.

    ``args.dim_start``, ``args.dim_end`` (inclusive) and ``args.dim_increment``
    are each a comma separated ``"M,K,N"`` string.
    """
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = (
        int(x) for x in args.dim_increment.split(",")
    )
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    # run() takes the parsed namespace plus the shape list. The parser
    # defines --act-type; there is no args.dtype.
    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")


def run_model_bench(args):
    """Benchmark the weight GEMM shapes of each model at each TP size.

    For every (model, tp_size) pair, each (K, N) shape from WEIGHT_SHAPES has
    its TP-split dimension divided by ``tp_size``. Every shape is benchmarked
    at every batch size in ``args.batch_sizes``. Results are printed per
    pair and pickled, together with the CLI args, to
    ``model_bench-<act_type>-<time>.pkl``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[list[int]]:
        # Deep-copy so the shared WEIGHT_SHAPES table is not modified.
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        KNs = model_shapes(model, tp_size)
        MKNs = [(m, k, n) for m in args.batch_sizes for k, n in KNs]
        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, (model, tp_size) in zip(model_bench_data, models_tps):
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_results = []
    for d in model_bench_data:
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        # vars(args) is the namespace's own __dict__, so popping from it would
        # delete args.func as a side effect. Build a filtered copy instead
        # ("func" is the subcommand handler and is not pickled).
        args_dict = {key: val for key, val in vars(args).items() if key != "func"}
        pkl.dump(
            {
                "args": args_dict,
                "results": all_results,
            },
            f,
        )
610
611
612
613
614


if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Map a dtype name accepted on the command line to a torch.dtype."""
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        """argparse action that stores the torch.dtype for the chosen name.

        Use this instead of ``type=to_torch_dtype``. argparse checks
        ``choices`` against the already type-converted value, so string
        choices would reject every dtype.
        """

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run a range of M,K,N (each flag takes a comma separated M,K,N list; the end is inclusive):
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 range_bench --dim-start 16,16384,16384 --dim-end 512,16384,16384 --dim-increment 16,1,1

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file containing the raw torch.utils.benchmark Measurements of the benchmarked implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-zero-type",
        # Previously type=to_torch_dtype, which made argparse compare the
        # converted dtype against the string choices and reject every value.
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument(
        "--sweep-csv-out",
        help="CSV to store sweep results",
        default="sch_sweep_results.csv",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list",
    )
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list",
    )
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    # bench() fills this DataFrame only when --sweep-schedules timed something.
    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)