# bench_cutlass_mla.py — benchmark for the CUTLASS MLA decode kernel.
import argparse
import copy
import itertools
import os

import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

from sglang.srt.utils import get_device_capability

# Detect whether we are running under a CI system (generic CI flag or GitHub Actions).
IS_CI = any(
    os.getenv(flag, "false").lower() == "true" for flag in ("CI", "GITHUB_ACTIONS")
)

# Keep the parameter sweep tiny under CI; use the full grid for real benchmark runs.
if IS_CI:
    bs_range = [1]  # Single batch size for CI
    qlen_range = [64]  # Single sequence length for CI
else:
    bs_range = [1, 8, 32, 64, 128, 256]
    qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# Every (batch_size, seq_len) pair the benchmark will sweep over.
configs = [(bs, qlen) for bs in bs_range for qlen in qlen_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        line_names=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    """Time cutlass_mla_decode under a CUDA graph and report achieved GB/s.

    Args:
        batch_size: number of decode requests in the batch.
        seq_len: KV sequence length shared by every request.
        provider: line label ("128 heads", ...) encoding the query head count.
        block_size: KV-cache paging block size.
        num_kv_splits: split factor forwarded to the kernel (-1 = auto).

    Returns:
        (median, max, min) throughput in GB/s over the benchmark quantiles.

    Raises:
        ValueError: if no known head count appears in `provider`.
    """
    # Head dimensions of the MLA layout: total (d), rope part (dn), value part (dv).
    d = 576
    dn = 64
    dv = 512

    # Extract the query head count from the provider label (first key that
    # appears as a substring wins; dict order is insertion order).
    head_counts = {
        "128": 128,
        "64": 64,
        "32": 32,
        "16": 16,
    }
    h_q = None
    for token, count in head_counts.items():
        if token in provider:
            h_q = count
            break
    if h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")

    # All requests share the same length, so max == seq_len here.
    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    blocks_per_seq = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    blocks_per_seq = (
        (blocks_per_seq + pack_factor - 1) // pack_factor
    ) * pack_factor

    # Random queries: non-rope part (qn) is laid out heads-first and transposed
    # at call time; rope part (qr) is batch-first.
    qn = (
        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
        * 100.0
    )
    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0

    # Random page table mapping each logical block to a physical cache block.
    block_table = torch.randint(
        0,
        batch_size * blocks_per_seq,
        (batch_size, blocks_per_seq),
        dtype=torch.int32,
        device="cuda",
    )

    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        blocks_per_seq * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        lambda: cutlass_mla_decode(
            qn.transpose(0, 1),
            qr,
            kv_cache,
            seq_lens,
            block_table,
            workspace,
            1.44,  # softmax scale -- presumably 1/ln(2)-style factor; confirm with kernel docs
            num_kv_splits,
        ),
        quantiles=[0.5, 0.2, 0.8],
    )

    # Bytes moved: queries read once, outputs written (scaled by dv/d), plus the
    # whole KV cache read once.
    q_bytes = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
    kv_bytes = kv_cache.numel() * kv_cache.element_size()

    def gbps(elapsed_ms):
        return (q_bytes + q_bytes * dv / d + kv_bytes) * 1e-9 / (elapsed_ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        # Fixed: help text previously said "List of batch sizes" (copy-paste bug).
        help="List of block sizes",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        # Fixed: help text previously said "List of batch sizes" (copy-paste bug).
        help="List of num_kv_splits values (-1 lets the kernel decide)",
    )
    args = parser.parse_args()

    # In CI, skip on unsupported architectures; CUTLASS MLA requires
    # compute capability 10.0+. Outside CI the check is intentionally not
    # performed (matches the original behavior).
    skip = False
    if IS_CI:
        major, minor = get_device_capability()
        if major is None or major < 10:  # Requires compute capability 10.0+
            skip = True
            print("Skipping Cutlass MLA benchmark in CI environment")
            if major is not None:
                print(
                    f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
                )
            else:
                print("Could not determine device capability")

    # Single sweep loop (previously duplicated verbatim in both branches).
    if not skip:
        for block_size in args.block_sizes:
            for kv_split in args.num_kv_splits:
                print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
                benchmark.run(
                    print_data=True,
                    block_size=block_size,
                    num_kv_splits=kv_split,
                )
        print("Benchmark finished!")