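"""Grouped-query attention (GQA) decode kernels written in TileLang.

A single query token per head attends to a full KV cache whose `groups` KV heads
are shared by `heads` query heads. Two paths are generated: a plain
FlashAttention-style decode kernel, and a split-KV variant that partitions the
KV sequence across `num_split` blocks and merges the partial results with a
log-sum-exp combine kernel. A uint8 mask can disable individual KV positions.

Usage:
    python example_gqa_decode.py --batch 1 --heads 32 --groups 8 --kv_seqlen 8192 --dim 128
    python example_gqa_decode.py --tune   # autotune over get_configs()
"""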
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import autotune
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import itertools
from functools import lru_cache
from typing import Tuple, Dict

torch.random.manual_seed(0)


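# Autotuning search space: KV tile size (block_N), query-head tile size (block_H),
# number of KV splits, pipeline depth, and threads per block.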
def get_configs():
    block_N = [64, 128]
    block_H = [64]
    num_split = [1, 2, 4, 8]
    num_stages = [1, 2, 3]
    threads = [128]
    _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))

    configs = [
        {"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]}
        for c in _configs
    ]
    return configs


@lru_cache(maxsize=1)
def get_heuristic_config() -> Tuple[Dict, int]:
    # Get CUDA device properties
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    device = torch.cuda.current_device()
    sm_major, sm_minor = torch.cuda.get_device_capability(device)
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version == 89:
        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
    else:
        cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
    return cfg, sm_version


# TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs():
    return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}


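# The autotuner sweeps get_configs(); tilelang.jit compiles the chosen config and,
# via out_idx=[6], allocates and returns the seventh tensor argument (Output).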
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e): lets the softmax use exp2 instead of exp
    shape_q = [batch, heads, dim]
    shape_k = [batch, seqlen_kv, groups, dim]
    shape_v = [batch, seqlen_kv, groups, dim]
    shape_o = [batch, heads, dim]
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // groups  # query heads sharing each KV head

    part_shape = [batch, heads, num_split, dim]  # per-split partial outputs
    valid_block_H = min(block_H, kv_group_num)  # head-tile rows that map to real query heads
    valid_block_N = min(block_N, seqlen_kv // num_split)  # KV-tile size within one split

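    # Non-split decode path: one thread block per (batch, block of query heads
    # sharing a KV head); the online softmax is kept in log2 space so exp2/log2
    # can be used throughout.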
    @T.macro
    def flash_attn(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
        Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            mask_local = T.alloc_fragment([block_N], "uint8")
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            cur_kv_head = hid // (kv_group_num // valid_block_H)

            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

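            # Sweep the KV cache tile by tile; T.Pipelined software-pipelines the
            # shared-memory copies and GEMMs across num_stages stages.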
            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
                T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_H, block_N):
                    # disable attention to masked-out KV positions
                    acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
                # online softmax: carry the previous row max, then fold in this tile's max
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            # epilogue: normalize by the accumulated softmax denominator
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])

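    # Split-KV decode path: the third grid dimension (sid) selects one contiguous
    # (seqlen_kv // num_split)-long slice of the KV cache; each split writes a
    # partial output plus its log-sum-exp into Output_partial/glse.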
    @T.macro
    def flash_attn_split(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
        glse: T.Tensor([batch, heads, num_split], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            mask_local = T.alloc_fragment([block_N], "uint8")
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // valid_block_H)

            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    K[
                        bid,
                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
                        cur_kv_head,
                        :,
                    ],
                    K_shared,
                )
                T.copy(
                    mask[
                        bid,
                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
                        cur_kv_head,
                    ],
                    mask_local,
                )
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_H, block_N):
                    # disable masked positions and anything past this split's KV slice
                    acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(
                    V[
                        bid,
                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
                        cur_kv_head,
                        :,
                    ],
                    V_shared,
                )
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            # epilogue: normalize, then convert logsum into the log2-domain log-sum-exp
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

            # store each valid head's log-sum-exp so combine() can weight this split
            for i in T.Parallel(block_H):
                if i < valid_block_H:
                    glse[bid, hid * valid_block_H + i, sid] = logsum[i]
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])

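    # Merge the num_split partial results: recover each split's weight from its
    # log-sum-exp with a max-stabilized exp2, then accumulate the weighted
    # partial outputs. The per-head LSE values are broadcast across 128 threads.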
    @T.macro
    def combine(
        glse: T.Tensor([batch, heads, num_split], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
        Output: T.Tensor(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local = T.alloc_fragment([num_split, 128], dtype)
            lse_logsum_local = T.alloc_fragment([128], accum_dtype)
            lse_max_local = T.alloc_fragment([128], accum_dtype)
            scale_local = T.alloc_fragment([128], accum_dtype)

            T.annotate_layout(
                {
                    lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                    lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
                    # lse_local: (local_id, thread_id)
                    lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
                }
            )

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            for k, j in T.Parallel(num_split, 128):
                lse_local[k, j] = glse[bz, by, k]
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
            for k in T.serial(num_split):
                for j in T.Parallel(128):
                    lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
            for j in T.Parallel(128):
                lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                for j in T.Parallel(128):
                    scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
                # Note: scale_local is one value per thread; this indexing assumes dim
                # matches the 128 threads in T.Parallel
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[i]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

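    # Entry points: for num_split > 1 the split kernel and the combine kernel run
    # back to back inside one program; otherwise the plain kernel writes Output.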
    @T.prim_func
    def flashattn_gqa_decode_split(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
        glse: T.Tensor([batch, heads, num_split], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
        Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, mask, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def flashattn_gqa_decode_no_split(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
        glse: T.Tensor([batch, heads, num_split], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
        Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn(Q, K, V, mask, Output)

    if num_split > 1:
        return flashattn_gqa_decode_split
    else:
        return flashattn_gqa_decode_no_split


def ref_program(query, key, value, mask, glse, Output_partial):
    #     """
    #     Inputs:
    #     - query (Tensor): [batch, heads, dim]
    #     - key (Tensor): [batch, seqlen_kv, groups, dim]
    #     - value (Tensor): [batch, seqlen_kv, groups, dim]
    #     - mask (Tensor): [batch, seqlen_kv, groups]
    #     Outputs:
    #     - output (Tensor): [batch, heads, dim]
    #     """
    dim = query.shape[-1]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]
    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]

    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]

    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, groups, seqlen_kv]
    if mask is not None:
        mask = rearrange(mask, "b s h -> b h s")
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, float("-inf"))

    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]

    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


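# PyTorch emulation of the split-KV kernel: returns per-split log2-domain
# log-sum-exp values and partial outputs. It uses its own fixed split count;
# the combined result is independent of the number of splits.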
def flash_split_ref(Q, K, V, mask):
    num_split = 16
    batch = Q.size(0)
    nheads = Q.size(1)
    groups = K.size(2)
    dim = Q.size(-1)
    block_N = 32
    seqlen_kv = K.size(1)
    num_head_groups = nheads // groups

    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e), matching the kernel's exp2-based softmax
    acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float("-inf"))
        scores_max_prev.fill_(float("-inf"))
        for i in range((seqlen_kv // num_split) // block_N):
            acc_s.fill_(0)
            acc_s = torch.einsum(
                "bghd,bkhd->bghk",
                Q_,
                K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
            )  # [batch, num_head_groups, groups, block_N]
            if mask is not None:
                mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
                mask_local = rearrange(mask_local, "b s h -> b h s")
                mask_local = mask_local.unsqueeze(1)
                acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, num_head_groups, groups]
            scores_scale = torch.exp2(scores_max_prev - scores_max)
            acc_o *= scores_scale[:, :, :, None]
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)  # [batch, num_head_groups, groups, block_N]
            acc_o += torch.einsum(
                "bghk,bkhd->bghd",
                acc_s_cast,
                V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
            )
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
        logsum_out = rearrange(logsum, "b g h->b (h g)")
        acc_o_out /= logsum_out[:, :, None]
        logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
        gacc_o[ks, :, :, :] = acc_o_out
        glogsum[ks, :, :] = logsum_out

    return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)


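# Reference combine step: merges the per-split partial outputs with a
# numerically stable log-sum-exp over the split dimension (mirrors combine()).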
def reduce_ref(Q, K, V, mask, glse, Output_partial):
    num_split = 16
    o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)  # [batch, heads]
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks]
        scale = torch.exp2(lse - lse_logsum)  # [batch, heads]
        o += Output_partial[:, :, ks, :] * scale[:, :, None]
    return o.to(torch.float16)


def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None):
    glse_, Output_partial_ = flash_split_ref(Q, K, V, mask)
    return reduce_ref(Q, K, V, mask, glse_, Output_partial_)


def print_red_warning(msg):
    print(f"\033[91m{msg}\033[0m")


def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f"{name} all zero")
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
    sim = calc_sim(x, y, name)
    diff = 1.0 - sim
    if not (0 <= diff <= eps):
        print_red_warning(f"{name} Error: {diff}")
        if assert_:
            raise AssertionError(f"{name} Error: {diff}")
    else:
        if print_:
            print(f"passed: {name} diff={diff}")


def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
    # decode FLOPs: one query token per head attends to kv_seqlen keys,
    # so QK^T and P@V each cost 2 * batch * heads * kv_seqlen * dim
    qk_flops = 2 * batch * heads * kv_seqlen * dim
    pv_flops = 2 * batch * heads * kv_seqlen * dim
    total_flops = qk_flops + pv_flops

    if not tune:
        config, sm_version = get_heuristic_config()
        kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)

        q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
        k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
        v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
        mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
        split = config["num_split"]
        glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
        Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
        o = kernel(q, k, v, mask, glse, Output_partial)
        o_ref = ref_program(q, k, v, mask, glse, Output_partial)
        o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)

        print(o)
        print(o_ref)

        assert_similar(o, o_ref, name="o_ref")
        assert_similar(o, o_ref_split, name="o_ref_split")

        print("All checks pass.")
        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--groups", type=int, default=8, help="groups")
    parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)