bench_cutlass_mla.py 3.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

# Sweep axes of the benchmark grid. Each (batch size, sequence length) pair is
# fed to the benchmark as its ("batch_size", "seq_len") x-values.
bs_range = [1, 8, 32, 64, 128, 256]
# 1, followed by the powers of two from 64 up to 8192.
qlen_range = [1, *(64 << shift for shift in range(8))]

# Full cartesian product, batch size varying slowest.
configs = [(bs, qlen) for bs in bs_range for qlen in qlen_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        line_names=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    """Benchmark ``cutlass_mla_decode`` and report its effective bandwidth.

    Args:
        batch_size: Number of requests; first dimension of the query tensor
            and of the block table.
        seq_len: KV length assigned to every request in the batch.
        provider: Line label such as ``"128 heads"``; the query head count is
            parsed from it.
        block_size: Number of tokens stored per KV-cache block.
        num_kv_splits: Forwarded unchanged to the workspace-size query and to
            the kernel call.

    Returns:
        A 3-tuple of GB/s figures computed from the latencies measured at the
        0.5, 0.8 and 0.2 quantiles, in that order (i.e. median throughput,
        then the lower and the higher throughput bound).

    Raises:
        ValueError: If ``block_size`` is not positive, or ``provider`` does
            not contain one of the known head counts.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    d = 576  # last dimension of both the query and the KV cache
    dv = 512  # only used below to size the output in the traffic estimate

    # The head count is the first known number found in the provider label;
    # 128 is tested first, matching the original lookup order.
    h_q = next((heads for heads in (128, 64, 32, 16) if str(heads) in provider), None)
    if h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks. Blocks wider
    # than 128 need no packing, so clamp the factor to 1 instead of letting it
    # reach 0 and trigger a ZeroDivisionError in the rounding below.
    pack_factor = max(1, 128 // block_size)
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0

    # Random block ids; every id is a valid index because the cache below is
    # allocated with exactly batch_size * block_num blocks.
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )

    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: cutlass_mla_decode(
            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
        ),
        quantiles=quantiles,
    )

    # Traffic estimate: the query, an output scaled from the query size by
    # dv / d, and the whole allocated KV cache (padding blocks included).
    q_bytes = q.numel() * q.element_size()
    total_bytes = q_bytes + q_bytes * dv / d + kv_cache.numel() * kv_cache.element_size()

    def gbps(latency_ms):
        """Convert a latency in milliseconds into GB/s for ``total_bytes``."""
        return total_bytes * 1e-9 / (latency_ms * 1e-3)

    # Throughput is inversely proportional to latency, so the slower quantile
    # (max_ms) yields the lower bound and the faster one the upper bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def build_arg_parser():
    """Build the command-line parser for the benchmark script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--block-sizes`` and
        ``--num-kv-splits``; each takes one or more integers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        help="List of KV-cache block sizes to benchmark",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        help="List of num_kv_splits values to benchmark",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Run the whole (batch size, sequence length) sweep once per combination
    # of block size and KV split count.
    # NOTE(review): every run uses the same save_path and plot name, so saved
    # files from earlier combinations appear to be overwritten by later ones;
    # confirm whether per-combination output is wanted.
    for block_size in args.block_sizes:
        for kv_split in args.num_kv_splits:
            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_blackwell_mla_res",
                block_size=block_size,
                num_kv_splits=kv_split,
            )

    print("Benchmark finished!")