# example_chunk_delta_bwd.py
# Reference: fla/ops/common/chunk_delta_h.py

import sys  # noqa: F401

import tilelang
import tilelang.language as T

print(tilelang.__file__, flush=True)

# Add your fla repository path to sys.path.
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae.
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__, flush=True)
    from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
    # `fla` doubles as the availability flag: it is None when the package
    # (or the reference kernel inside it) could not be imported.
    print("fla not found, using tilelang implementation")
    fla = None

import torch
import torch.nn.functional as F

# Fixed seed so the randomly generated inputs are reproducible across runs.
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")

from test_utils import assert_similar


def prepare_input(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    device="cuda",
):
    """Build random inputs for the chunked gated delta-rule backward pass.

    Args:
        B, S, H, DK, DV: Batch size, sequence length, number of heads, key
            dimension and value dimension used for the tensor shapes.
        chunk_size: Chunk length handed to fla's ``chunk_local_cumsum``.
        input_dtype: dtype of every tensor except ``G``.
        output_dtype, accum_dtype, state_dtype: Unused here; accepted so the
            signature mirrors the other helpers in this file.
        gate_dtype: dtype of the gate tensor ``G``.
        device: Device the tensors are moved to (defaults to ``"cuda"``).

    Returns:
        Tuple ``(Q, K, W, G, h0, dht, dO, dv)`` with shapes
        ``Q/K/W: (B, S, H, DK)``, ``G: (B, S, H)``, ``h0/dht: (B, H, DK, DV)``
        and ``dO/dv: (B, S, H, DV)``. ``K`` is L2-normalised along its last
        dimension and ``G`` holds log-sigmoid values.
    """
    # Values are drawn on the CPU and then moved, so the seeded random stream
    # is the same whichever device is requested.
    Q = torch.randn(B, S, H, DK, dtype=input_dtype).to(device)
    K = torch.randn(B, S, H, DK, dtype=input_dtype).to(device)
    K = F.normalize(K, dim=-1, p=2)
    W = torch.randn(B, S, H, DK, dtype=input_dtype).to(device)
    # Note: G should be in logspace and do chunkwise cumsum
    G = torch.randn(B, S, H, dtype=gate_dtype).to(device)
    G = F.logsigmoid(G)
    try:
        from fla.ops.utils.cumsum import chunk_local_cumsum
    except ImportError:
        # Without fla, G stays as plain log-sigmoid gates (no cumulative sum).
        print("fla not found, skip cumsum")
    else:
        G = chunk_local_cumsum(G, chunk_size)

    h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).to(device)
    dht = torch.randn(B, H, DK, DV, dtype=input_dtype).to(device)
    dO = torch.randn(B, S, H, DV, dtype=input_dtype).to(device)
    dv = torch.randn(B, S, H, DV, dtype=input_dtype).to(device)
    return Q, K, W, G, h0, dht, dO, dv


def prepare_input_fake(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    device="cuda",
):
    """Build all-ones inputs with the same layout as ``prepare_input``.

    Handy for debugging because every expected value is easy to derive.

    Args:
        B, S, H, DK, DV: Batch size, sequence length, number of heads, key
            dimension and value dimension used for the tensor shapes.
        chunk_size, output_dtype, accum_dtype, state_dtype: Unused here;
            accepted so the signature matches ``prepare_input``.
        input_dtype: dtype of every tensor except ``G``.
        gate_dtype: dtype of the gate tensor ``G``.
        device: Device the tensors are created on (defaults to ``"cuda"``).

    Returns:
        Tuple ``(Q, K, W, G, h0, dht, dO, dv)`` of tensors filled with ones.
    """

    def ones(*shape, dtype=input_dtype):
        # Allocate directly on the target device instead of CPU + copy.
        return torch.ones(*shape, dtype=dtype, device=device)

    Q = ones(B, S, H, DK)
    K = ones(B, S, H, DK)
    W = ones(B, S, H, DK)
    G = ones(B, S, H, dtype=gate_dtype)
    h0 = ones(B, H, DK, DV)
    dht = ones(B, H, DK, DV)
    dO = ones(B, S, H, DV)
    dv = ones(B, S, H, DV)
    return Q, K, W, G, h0, dht, dO, dv


def prepare_output(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    output_dtype,
    gate_dtype,
    state_dtype,
    device="cuda",
):
    """Allocate uninitialised output buffers for the backward kernel.

    Args:
        B, S, H, DK, DV: Batch size, sequence length, number of heads, key
            dimension and value dimension used for the tensor shapes.
        chunk_size: Chunk length; ``S // chunk_size`` per-chunk states are
            allocated in ``dh``.
        output_dtype: dtype of ``dh`` and ``dv2``.
        gate_dtype: Unused here; accepted for signature symmetry.
        state_dtype: dtype of ``dh0``.
        device: Device the buffers are created on (defaults to ``"cuda"``).

    Returns:
        Tuple ``(dh, dh0, dv2)`` with shapes ``(B, S // chunk_size, H, DK, DV)``,
        ``(B, H, DK, DV)`` and ``(B, S, H, DV)``. Contents are uninitialised.
    """
    BS = S // chunk_size
    dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype, device=device)
    dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype, device=device)
    dv2 = torch.empty(B, S, H, DV, dtype=output_dtype, device=device)
    return dh, dh0, dv2


def torch_chunk_gated_delta_rule_bwd_dhu(
    Q: torch.Tensor,
    K: torch.Tensor,
    W: torch.Tensor,
    G: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    dO: torch.Tensor,
    dv: torch.Tensor,
    scale: float,
    use_g: bool,
    use_initial_state: bool,
    use_final_state_gradient: bool,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
):
    """Pure-PyTorch reference for the chunked gated delta-rule backward pass.

    Chunks of 64 positions are visited from last to first. For each chunk the
    running state gradient is stored into ``dh``, the value gradient is
    updated into ``dv2`` and the state gradient is propagated to the previous
    chunk.

    Args:
        Q, K, W: ``(B, S, H, DK)`` tensors. They are only read; ``Q`` is no
            longer modified in place.
        G: ``(B, S, H)`` log-space gates; only read when ``use_g`` is true, so
            it may be ``None`` otherwise.
        h0: Unused; kept for signature parity with the other implementations.
        dht: ``(B, H, DK, DV)`` gradient of the final state.
        dO: ``(B, S, H, DV)`` gradient of the output.
        dv: ``(B, S, H, DV)`` incoming value gradient.
        scale: Scalar multiplied into ``Q``.
        use_g: Apply the gate decay terms.
        use_initial_state: Return the accumulated state gradient as ``dh0``
            (zeros otherwise).
        use_final_state_gradient: Start the recurrence from ``dht`` (zeros
            otherwise).
        input_dtype, gate_dtype: Unused; kept for signature parity.
        output_dtype: dtype of ``dh`` and ``dv2``.
        accum_dtype: dtype of the running state gradient.
        state_dtype: dtype of the returned ``dh0``.

    Returns:
        ``(dh, dh0, dv2)`` with shapes ``(B, S // 64, H, DK, DV)``,
        ``(B, H, DK, DV)`` and ``(B, S, H, DV)``, all on ``Q``'s device.
    """
    B, S, H, DK = Q.shape
    DV = dv.shape[-1]
    block_S = 64
    BS = S // block_S

    # Allocate on the input device so all three results share one device.
    dh = torch.empty((B, BS, H, DK, DV), dtype=output_dtype, device=Q.device)
    dv2 = torch.empty((B, S, H, DV), dtype=output_dtype, device=Q.device)

    if use_final_state_gradient:
        dh_tmp = dht.clone().to(accum_dtype)
    else:
        dh_tmp = torch.zeros_like(dht).to(accum_dtype)

    # Remember the caller's TF32 setting so it can be restored afterwards.
    tf32_before = torch.backends.cuda.matmul.allow_tf32

    for i_s in range(BS - 1, -1, -1):
        chunk = slice(i_s * block_S, (i_s + 1) * block_S)

        # Store the state gradient seen by this chunk.
        dh[:, i_s] = dh_tmp

        # dv contribution flowing back through the state: K_chunk @ dh.
        dv_tmp = torch.matmul(K[:, chunk].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
        if use_g:
            G_chunk = G[:, chunk]  # (B, block_S, H)
            G_last = G_chunk[:, -1]  # (B, H)
            decay = G_last[:, None, :] - G_chunk
            # Positions whose decay exponent is positive contribute nothing.
            factor = torch.where(decay <= 0, torch.exp(decay), torch.zeros_like(decay))
            dv_tmp = dv_tmp * factor.unsqueeze(-1).to(dv_tmp.dtype)
        dv_tmp = dv_tmp + dv[:, chunk]
        dv2[:, chunk] = dv_tmp

        if use_g:
            dh_tmp *= torch.exp(G_last)[:, :, None, None].to(dh_tmp.dtype)
            # Out-of-place so that the caller's Q is left untouched.
            Q_tmp = (Q[:, chunk] * torch.exp(G_chunk).unsqueeze(-1)).to(Q.dtype)
        else:
            Q_tmp = Q[:, chunk]
        Q_tmp = Q_tmp * scale

        W_tmp = W[:, chunk]
        dO_tmp = dO[:, chunk]

        torch.backends.cuda.matmul.allow_tf32 = True
        try:
            dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
            dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3))
        finally:
            torch.backends.cuda.matmul.allow_tf32 = tf32_before

    if use_initial_state:
        dh0 = dh_tmp.to(state_dtype)
    else:
        dh0 = torch.zeros_like(dh_tmp, dtype=state_dtype)

    return dh, dh0, dv2


@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_bwd_dhu(
    # task config
    B,
    S,
    H,
    DK,
    DV,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    chunk_size,
    scale,
    use_g=True,
    use_initial_state=True,
    use_final_state_gradient=True,
    # kernel config
    block_DV=64,
    threads=256,
    num_stages=0,
):
    """Build the TileLang kernel for the chunked gated delta-rule backward pass.

    The last three kernel parameters (``dh``, ``dh0``, ``dv2``) are declared as
    outputs through ``out_idx=[-3, -2, -1]``.

    Args:
        B, S, H, DK, DV: Batch size, sequence length, number of heads, key
            dimension and value dimension.
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype:
            dtypes of the inputs, of ``dh``/``dv2``, of the accumulators, of
            ``G`` and of ``dh0`` respectively.
        chunk_size: Number of sequence positions processed per loop step.
        scale: Scalar multiplied into ``Q``.
        use_g: Apply the gate terms read from ``G``.
        use_initial_state: Write the final accumulated state gradient to
            ``dh0``. When false, this kernel never writes ``dh0``.
        use_final_state_gradient: Initialise the state gradient from ``dht``
            (cleared to zero otherwise).
        block_DV: Tile width along the value dimension.
        threads: Thread count passed to ``T.Kernel``.
        num_stages: ``num_stages`` passed to ``T.Pipelined``.

    Returns:
        The ``T.prim_func`` kernel taking ``(Q, K, W, G, h0, dht, dO, dv)``.
    """
    block_S = chunk_size
    # Should support cu_seqlen
    BS = S // block_S

    Q_shape = (B, S, H, DK)
    K_shape = (B, S, H, DK)
    W_shape = (B, S, H, DK)
    G_shape = (B, S, H)
    h0_shape = (B, H, DK, DV)
    dht_shape = (B, H, DK, DV)
    dO_shape = (B, S, H, DV)
    dv_shape = (B, S, H, DV)

    dh_shape = (B, BS, H, DK, DV)
    dh0_shape = (B, H, DK, DV)
    dv2_shape = (B, S, H, DV)

    @T.prim_func
    def kernel(
        # Input
        Q: T.Tensor(Q_shape, dtype=input_dtype),
        K: T.Tensor(K_shape, dtype=input_dtype),
        W: T.Tensor(W_shape, dtype=input_dtype),
        G: T.Tensor(G_shape, dtype=gate_dtype),
        h0: T.Tensor(h0_shape, dtype=input_dtype),
        dht: T.Tensor(dht_shape, dtype=input_dtype),
        dO: T.Tensor(dO_shape, dtype=input_dtype),
        dv: T.Tensor(dv_shape, dtype=input_dtype),
        # Output
        dh: T.Tensor(dh_shape, dtype=output_dtype),
        dh0: T.Tensor(dh0_shape, dtype=state_dtype),
        dv2: T.Tensor(dv2_shape, dtype=output_dtype),
    ):
        # One block per (DV tile, batch * head) pair.
        with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
            bb, bh = bbh // H, bbh % H

            # Running state gradient and its staging buffers.
            b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype)
            # Only referenced by the layout annotation below.
            b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype)
            b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
            b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
            b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
            dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
            dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
            dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32)
            dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32)
            dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32)
            K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)

            Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
            # Only referenced by the layout annotation below.
            Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32)
            W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)

            G_last_local = T.alloc_local((1), dtype=gate_dtype)
            G_last_local_exp = T.alloc_local((1), dtype=gate_dtype)
            G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared")
            G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype)
            G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype)
            G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype)
            Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
            Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)

            T.use_swizzle(10)

            T.annotate_layout(
                {
                    b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
                    b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
                    dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
                    Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
                    W_shared: tilelang.layout.make_swizzled_layout(W_shared),
                }
            )

            # Seed the running state gradient from dht, or start from zero.
            if use_final_state_gradient:
                T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
                T.copy(b_dh_shared, b_dh_fragment)
            else:
                T.clear(b_dh_fragment)

            for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
                # The gradient should be stored in the reverse order
                i_s_inv = T.ceildiv(S, block_S) - i_s - 1

                # Store the updated dh
                T.copy(b_dh_fragment, b_dh_shared)
                T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])

                # Update dv
                T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
                T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)

                if use_g:
                    T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
                    T.copy(G_shared, G_fragment)
                    G_last_local[0] = G_shared[block_S - 1]
                    G_last_local_exp[0] = T.exp(G_last_local[0])
                    for i_s2 in T.Parallel(block_S):
                        G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2])
                    for i_s2, i_v in T.Parallel(block_S, block_DV):
                        # Rows whose decay exponent is positive are zeroed.
                        # with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
                        with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
                            with T.Then():
                                dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
                            with T.Else():
                                dv_fragment[i_s2, i_v] = 0

                # Add the incoming dv for this chunk.
                T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
                T.copy(dv_shared, dv_fragment_2)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]

                # Store the updated dv
                T.copy(dv_fragment, dv_shared)
                T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])

                # Update dh
                T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
                T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)

                T.clear(Q_fragment)
                if use_g:
                    # Decay the carried state gradient by exp(G at the chunk's last row).
                    for i_k, i_v in T.Parallel(DK, block_DV):
                        b_dh_fragment[i_k, i_v] *= G_last_local_exp[0]
                    T.copy(Q_shared, Q_fragment)
                    for i_s2 in T.Parallel(block_S):
                        G_fragment_exp[i_s2] = T.exp(G_shared[i_s2])
                    for i_s2, i_k in T.Parallel(block_S, DK):
                        # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale
                        Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale
                else:
                    T.copy(Q_shared, Q_fragment)
                    for i_s2, i_k in T.Parallel(block_S, DK):
                        Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale
                # Get transpose of Q_fragment to meet tf32 gemm requirement
                for i_s2, i_k in T.Parallel(block_S, DK):
                    Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]

                # Stage dO as a float32, transposed tile for the gemm below.
                T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
                T.copy(dO_shared, dO_fragment)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
                T.copy(dO_fragment_t, dO_shared_t)

                # b_dh += (scaled Q)^T @ dO - W^T @ (updated dv held in dv_shared)
                T.clear(b_dh_fragment_1)
                T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True)
                T.clear(b_dh_fragment_2)
                T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True)
                for i_k, i_v in T.Parallel(DK, block_DV):
                    b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]

            # dh0 is only written on this path.
            if use_initial_state:
                T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])

    return kernel


def _report_allclose(name, label, lhs, rhs):
    """Print whether ``lhs`` and ``rhs`` agree within rtol = atol = 1e-2."""
    try:
        torch.testing.assert_close(lhs, rhs, rtol=1e-2, atol=1e-2, equal_nan=True)
    except Exception as e:
        # Report-only helper: the failure is printed rather than propagated.
        print(f"{name} {label}_0 and {label}_1 are not close for {name}")
        print(e, end="\n\n")
    else:
        print(f"{name} {label}_0 and {label}_1 passed for {name}")


def _print_mismatches(name, label, lhs, rhs, limit=100):
    """Print at most ``limit`` element pairs that differ beyond tolerance."""
    close = torch.isclose(lhs, rhs, rtol=1e-2, atol=1e-2)
    # Slice before iterating so a large mismatch set is not walked in full.
    for index in torch.nonzero(~close)[:limit].tolist():
        where = tuple(index)
        print(f"{name} {label}_0[{index}] = {lhs[where]}, {label}_1[{index}] = {rhs[where]}")


def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name):
    """Print a comparison report for two ``(dh, dh0, dv2)`` result triples.

    First prints, for each pair, whether ``torch.testing.assert_close`` passes
    with ``rtol = atol = 1e-2``; then lists up to 100 mismatching elements per
    pair. Nothing is returned and comparison failures are not raised.

    Args:
        dh_0, dh0_0, dv2_0: Results of the first implementation.
        dh_1, dh0_1, dv2_1: Results of the second implementation.
        name: Label prefixed to every printed line.
    """
    pairs = (
        ("dh", dh_0, dh_1),
        ("dh0", dh0_0, dh0_1),
        ("dv2", dv2_0, dv2_1),
    )
    for label, lhs, rhs in pairs:
        _report_allclose(name, label, lhs, rhs)
    for label, lhs, rhs in pairs:
        _print_mismatches(name, label, lhs, rhs)


# The ``test_`` prefix would make pytest collect this reporting helper as a
# test case (and fail on its arguments); opt it out explicitly.
test_result.__test__ = False


def run_test(
    B,
    S,
    H,
    DK,
    DV,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    chunk_size,
    scale,
    use_g=True,
    use_initial_state=True,
    use_final_state_gradient=True,
    block_DV=64,
    threads=256,
    num_stages=0,
    use_torch=False,
):
    """Run the TileLang kernel and compare/benchmark it against references.

    The fla reference is used only when the module-level ``fla`` import
    succeeded; otherwise its run, benchmark and comparisons are skipped
    instead of failing on an undefined name.

    Args:
        B, S, H, DK, DV: Problem sizes forwarded to the input builder and the
            kernel factory.
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype:
            dtype names; each is resolved with ``getattr(torch, name)`` for
            the PyTorch side and passed unchanged to the kernel factory.
        chunk_size: Chunk length used by the kernel and the fla reference.
        scale: Scalar multiplied into ``Q``.
        use_g: Apply gating. When false, ``G`` is filled with zeros.
        use_initial_state, use_final_state_gradient: Forwarded to the kernel
            factory and the PyTorch reference.
        block_DV, threads, num_stages: Kernel configuration.
        use_torch: Also run and compare the pure-PyTorch reference.
    """
    torch_dtypes = tuple(
        getattr(torch, dtype_name)
        for dtype_name in (input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype)
    )
    Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, *torch_dtypes)
    if not use_g:
        G = G.fill_(0)

    have_fla = fla is not None

    # fla ref
    if have_fla:
        print("fla running...", flush=True)
        # Pass chunk_size here too, matching the benchmark call below.
        dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(
            Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size
        )
    else:
        print("fla not available, skipping the fla reference", flush=True)

    # tilelang
    print("tilelang running...", flush=True)
    kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
        B,
        S,
        H,
        DK,
        DV,
        input_dtype,
        output_dtype,
        accum_dtype,
        gate_dtype,
        state_dtype,
        chunk_size,
        scale,
        use_g,
        use_initial_state,
        use_final_state_gradient,
        block_DV,
        threads,
        num_stages,
    )
    print(kernel.get_kernel_source())
    dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)

    if have_fla:
        fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
    tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)

    if have_fla:
        print(f"fla time: {fla_time} ms")
    print(f"tilelang time: {tilelang_time} ms")

    if have_fla:
        assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh")
        assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0")
        assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2")

    # torch ref
    if use_torch:
        print("torch running...", flush=True)
        dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
            Q,
            K,
            W,
            G if use_g else None,  # the reference does not read G without gating
            h0,
            dht,
            dO,
            dv,
            scale,
            use_g,
            use_initial_state,
            use_final_state_gradient,
            *torch_dtypes,
        )
        dh_ref_torch = dh_ref_torch.cuda()
        dh0_ref_torch = dh0_ref_torch.cuda()
        dv2_ref_torch = dv2_ref_torch.cuda()

        if have_fla:
            assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh")
            assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0")
            assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2")
        assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh")
        assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0")
        assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")


def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
    """
    Measure the mean latency of ``fn(*args, **kwargs)`` using CUDA events.

    ``fn`` is first called ``warmup`` times without timing, then ``rep`` more
    times, each call bracketed by its own pair of timing events. The return
    value is the mean of the per-call elapsed times as a Python float.
    """
    begin_marks = []
    finish_marks = []
    for _ in range(rep):
        begin_marks.append(torch.cuda.Event(enable_timing=True))
    for _ in range(rep):
        finish_marks.append(torch.cuda.Event(enable_timing=True))

    # Untimed warm-up calls.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()

    # Timed calls: one event pair per repetition.
    for begin, finish in zip(begin_marks, finish_marks):
        begin.record()
        fn(*args, **kwargs)
        finish.record()
    torch.cuda.synchronize()

    # Read the timings back once all recorded work has finished.
    elapsed = []
    for begin, finish in zip(begin_marks, finish_marks):
        elapsed.append(begin.elapsed_time(finish))
    samples = torch.tensor(elapsed, dtype=torch.float)

    return samples.mean().item()


def main():
    """Run the comparison once with a fixed bfloat16 configuration."""
    DK = 128
    run_test(
        B=1,
        S=32768,
        H=8,
        DK=DK,
        DV=128,
        input_dtype=T.bfloat16,
        output_dtype=T.bfloat16,
        accum_dtype=T.float32,
        gate_dtype=T.float32,
        state_dtype=T.float32,
        chunk_size=64,
        scale=DK**-0.5,
        use_g=True,
        use_initial_state=True,
        use_final_state_gradient=True,
        block_DV=32,
        threads=128,
        num_stages=1,
        use_torch=False,
    )


if __name__ == "__main__":
    main()