# example_chunk_delta_h.py
# Reference: fla/ops/common/chunk_delta_h.py

import sys  # noqa: F401
import tilelang
import tilelang.language as T
6
from tilelang.autotuner import autotune
7
8
9
10
11
12

# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    # Show which fla checkout is being used as the reference implementation.
    print(fla.__file__)
    from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
    print("fla not found, using tilelang implementation")
    fla = None
    # Keep the reference name bound so callers can test for it instead of
    # hitting a NameError.
    chunk_gated_delta_rule_fwd_h = None

import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback  # noqa: F401

24
from test_utils import assert_similar
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

# (zhengju) The CUDA code generated by tilelang lowering can be hand-tuned (e.g. an
# edited copy kept in the debug folder) for better performance. To use that code
# instead of the generated one, uncomment the following function.
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
#     cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read()
#     code = cuda_code
#     return code

# Seed torch's global RNG so the random tensors built in prepare_input are reproducible.
torch.manual_seed(0)


def _chunk_local_cumsum_torch(g, chunk_size):
    """Cumulative sum of ``g`` along axis 1, restarting every ``chunk_size`` steps.

    Pure-torch fallback used when fla is unavailable. ``g`` is indexed as
    (B, S, H). A ragged final chunk is zero-padded on the right before the
    per-chunk cumsum, and the padding is sliced off again.
    """
    batch, seq_len, heads = g.shape
    num_chunks = -(-seq_len // chunk_size)  # ceil division
    pad = num_chunks * chunk_size - seq_len
    # F.pad pads from the last dim backwards: (H_left, H_right, S_left, S_right).
    g_padded = F.pad(g, (0, 0, 0, pad))
    chunked = g_padded.reshape(batch, num_chunks, chunk_size, heads)
    return chunked.cumsum(dim=2).reshape(batch, num_chunks * chunk_size, heads)[:, :seq_len]


def prepare_input(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    device="cuda",
):
    """Create random inputs for the chunked gated delta-rule forward pass.

    Tensors are sampled with the CPU generator and then moved to ``device``, so
    the values depend only on the torch seed, not on the target device.

    Args:
        B, S, H: batch size, sequence length and number of heads.
        DK, DV: key and value head dimensions.
        chunk_size: chunk length used for the chunk-local gate cumsum.
        input_dtype: torch dtype of K, W, U and the initial state.
        output_dtype, accum_dtype: accepted for call-site symmetry; unused here.
        gate_dtype: torch dtype of the gate tensor G.
        device: device the returned tensors live on (default ``"cuda"``).

    Returns:
        ``(K, W, U, G, initial_state)`` with shapes (B, S, H, DK), (B, S, H, DK),
        (B, S, H, DV), (B, S, H) and (B, H, DK, DV). K, W and U are L2-normalized
        over their last dimension. G holds chunk-local cumulative sums of
        log-sigmoid gates.
    """
    K = F.normalize(torch.randn(B, S, H, DK, dtype=input_dtype).to(device), dim=-1, p=2)
    W = F.normalize(torch.randn(B, S, H, DK, dtype=input_dtype).to(device), dim=-1, p=2)
    U = F.normalize(torch.randn(B, S, H, DV, dtype=input_dtype).to(device), dim=-1, p=2)
    G = F.logsigmoid(torch.randn(B, S, H, dtype=gate_dtype).to(device))

    try:
        from fla.ops.utils.cumsum import chunk_local_cumsum
    except ImportError:
        # Without the cumsum the gates would not match what the kernel expects,
        # so compute the same chunk-local cumsum in plain torch.
        print("fla not found, using torch chunk-local cumsum")
        G = _chunk_local_cumsum_torch(G, chunk_size)
    else:
        G = chunk_local_cumsum(G, chunk_size)

    initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).to(device)
    return K, W, U, G, initial_state


def prepare_output(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    output_dtype,
    state_dtype,
    device="cuda",
):
    """Allocate uninitialized output buffers for the forward pass.

    Buffers are allocated directly on ``device``. This avoids a host
    allocation followed by a copy of uninitialized memory.

    Args:
        B, S, H, DK, DV: batch, sequence length, heads, key dim and value dim.
        chunk_size: chunk length; ``h`` gets ``S // chunk_size`` chunk slots.
        output_dtype: torch dtype of ``h`` and ``V_new``.
        state_dtype: torch dtype of ``final_state``.
        device: device to allocate on (default ``"cuda"``).

    Returns:
        ``(h, final_state, V_new)`` with shapes (B, S // chunk_size, H, DK, DV),
        (B, H, DK, DV) and (B, S, H, DV).
    """
    BS = S // chunk_size
    h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype, device=device)
    final_state = torch.empty(B, H, DK, DV, dtype=state_dtype, device=device)
    V_new = torch.empty(B, S, H, DV, dtype=output_dtype, device=device)
    return h, final_state, V_new


86
87
def get_configs():
    """Return the autotuning search space for the tilelang kernel.

    Each entry is a dict with keys ``block_DK``, ``block_DV``, ``threads`` and
    ``num_stages``. The entries form the full cartesian product of the candidate
    values, in ``itertools.product`` order (3 * 3 * 2 * 3 = 54 configs).

    NOTE(review): the kernel body never reads ``block_DK`` (its tiles span the
    whole DK), so the three ``block_DK`` values yield identical kernels and
    triple the tuning time. Consider dropping it from the search space.
    """
    import itertools

    block_DK = [32, 64, 128]
    block_DV = [32, 64, 128]
    threads = [128, 256]
    num_stages = [1, 2, 3]

    keys = ("block_DK", "block_DV", "threads", "num_stages")
    return [dict(zip(keys, values)) for values in itertools.product(block_DK, block_DV, threads, num_stages)]


@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_chunk_gated_delta_rule_fwd_h(
    # task config
    B,
    S,
    H,
    DK,
    DV,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    chunk_size,
    use_g,
    use_initial_state,
    store_final_state,
    save_new_value,
    # kernel config
    block_DK=64,
    block_DV=32,
    threads=128,
    num_stages=1,
):
    """Build the TileLang kernel for the chunked gated delta-rule forward state pass.

    The prim_func takes ``(K, W, U, G, initial_state, h, final_state, V_new)``.
    With ``out_idx=[-3, -2, -1]`` the compiled kernel returns
    ``(h, final_state, V_new)`` and is called as ``kernel(K, W, U, G, initial_state)``.

    Each thread block handles one (batch, head) pair and one ``block_DV``-wide
    slice of the value dimension. It carries a (DK, block_DV) running state
    through the sequence one chunk at a time. For chunk ``i_s``:

    * ``h[:, i_s]`` receives the state *before* the chunk is applied.
    * ``V_new = U - W @ state``; it is written out when ``save_new_value``.
    * When ``use_g``, V_new rows are scaled by ``exp(g_last - g)`` (and zeroed
      where ``g_last - g > 0``), and the state is scaled by ``exp(g_last)``.
      ``g_last`` is G at the chunk's last position.
    * ``state += K^T @ V_new``.

    After the last chunk, the state is written to ``final_state`` when
    ``store_final_state`` is set.

    Args:
        B, S, H, DK, DV: batch, sequence length, heads, key dim, value dim.
        input_dtype: dtype of K, W, U and initial_state.
        output_dtype: dtype of h and V_new.
        accum_dtype: dtype of the on-chip state / V_new / U fragments.
        gate_dtype: dtype of G.
        state_dtype: dtype of final_state.
        chunk_size: rows per chunk (``block_S``).
        use_g, use_initial_state, store_final_state, save_new_value: feature flags.
        block_DK: tunable, but not read by the kernel body (tiles span all of DK).
        block_DV: tile width over DV.
        threads: threads per block.
        num_stages: pipelining depth of the chunk loop.

    NOTE(review): ``BS = S // chunk_size`` while the loop runs
    ``ceildiv(S, chunk_size)`` iterations. This appears to assume S is a
    multiple of chunk_size — confirm.
    """
    block_S = chunk_size
    BS = S // block_S

    K_shape = (B, S, H, DK)
    V_shape = (B, S, H, DV)
    W_shape = (B, S, H, DK)
    U_shape = (B, S, H, DV)
    G_shape = (B, S, H)
    h_shape = (B, BS, H, DK, DV)
    initial_state_shape = (B, H, DK, DV)
    final_state_shape = (B, H, DK, DV)

    @T.prim_func
    def kernel(
        K: T.Tensor(K_shape, dtype=input_dtype),
        W: T.Tensor(W_shape, dtype=input_dtype),
        U: T.Tensor(U_shape, dtype=input_dtype),
        G: T.Tensor(G_shape, dtype=gate_dtype),
        initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
        h: T.Tensor(h_shape, dtype=output_dtype),
        final_state: T.Tensor(final_state_shape, dtype=state_dtype),
        V_new: T.Tensor(V_shape, dtype=output_dtype),
    ):
        # Grid: one block per block_DV-wide slice of DV (bv) and per (batch, head) pair (bbh).
        with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
            bb, bh = bbh // H, bbh % H

            # The running state lives in two places: b_h_shared is the gemm
            # operand and the source of the h store; b_h_fragment is the
            # accumulator updated by the K^T @ V_new gemm.
            b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype)
            b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)

            U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
            U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
            V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
            K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
            G_last_local = T.alloc_local((1), dtype=gate_dtype)
            # Per-row gates broadcast across the block_DV columns.
            G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
            G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)

            T.annotate_layout(
                {
                    b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
                    U_shared: tilelang.layout.make_swizzled_layout(U_shared),
                    W_shared: tilelang.layout.make_swizzled_layout(W_shared),
                    V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    G_shared: tilelang.layout.make_swizzled_layout(G_shared),
                }
            )

            T.use_swizzle(10)

            if use_initial_state:
                T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
                T.copy(b_h_shared, b_h_fragment)
            else:
                T.clear(b_h_fragment)
                # b_h_shared is read at the top of the first iteration (the h
                # store and the W @ h gemm), so it must hold the zero state as
                # well. Shared memory is not zero-initialized.
                T.copy(b_h_fragment, b_h_shared)

            for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
                # Store previous result to the hidden tensor, like the epilogue
                T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])

                # Recurrence
                T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
                T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)

                # U - W * S
                T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
                T.copy(U_shared, U_fragment)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]

                # Save V_new (before gating)
                if save_new_value:
                    T.copy(V_new_fragment, dst=V_new_shared)
                    T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])

                T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)

                # use_g
                if use_g:
                    G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
                    for i_s2, i_v in T.Parallel(block_S, block_DV):
                        G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh]
                    T.copy(G_shared, G_fragment)
                    for i_s2, i_v in T.Parallel(block_S, block_DV):
                        # Scale by exp(g_last - g) only where the exponent is <= 0; zero otherwise.
                        # 1.442695 ~= log2(e), so exp2(x * 1.442695) == exp(x).
                        with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
                            with T.Then():
                                V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
                                    (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
                                )
                            with T.Else():
                                V_new_fragment[i_s2, i_v] = 0
                    # Decay the carried state by exp(g_last).
                    G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
                    for i_k, i_v in T.Parallel(DK, block_DV):
                        b_h_fragment[i_k, i_v] *= G_last_local[0]

                # Update intermediate results: state += K^T @ V_new
                T.copy(V_new_fragment, V_new_shared)
                T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True)

                T.copy(b_h_fragment, b_h_shared)

            # Save final state
            if store_final_state:
                T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])

    return kernel


def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
    """Time ``fn(*args, **kwargs)`` with CUDA events.

    First runs ``warmup`` untimed calls. Then runs ``rep`` calls, each bracketed
    by its own pair of timing-enabled CUDA events, synchronizing the device
    before and after the timed section.

    Returns:
        The mean of the per-call ``Event.elapsed_time`` values (milliseconds),
        as a Python float.
    """
    event_pairs = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in range(rep)
    ]

    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()

    for begin, finish in event_pairs:
        begin.record()
        fn(*args, **kwargs)
        finish.record()
    torch.cuda.synchronize()

    elapsed_ms = [begin.elapsed_time(finish) for begin, finish in event_pairs]
    return torch.tensor(elapsed_ms, dtype=torch.float).mean().item()


def _report_similarity(ref, out, label):
    """Compare ``ref`` and ``out`` (cast to float32) with assert_similar and print pass/fail.

    ``label`` is appended to ``"tilelang chunk gated delta rule fwd "`` to form
    both the assert_similar name and the printed message.
    """
    name = f"tilelang chunk gated delta rule fwd {label}"
    try:
        # Assumes assert_similar raises when raise_assert=True and the tensors
        # differ (TODO confirm in test_utils). With raise_assert=False a mismatch
        # was still followed by a "passed" line.
        assert_similar(ref.to(torch.float32), out.to(torch.float32), eps=1e-5, name=name, raise_assert=True)
    except Exception as e:
        print(f"{name} failed ✗")
        print(e)
    else:
        print(f"{name} passed √")


def run_test(
    B,
    S,
    H,
    DK,
    DV,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    state_dtype,
    chunk_size,
    use_g=True,
    use_initial_state=True,
    store_final_state=True,
    save_new_value=True,
    block_DK=64,
    block_DV=32,
    threads=128,
    num_stages=0,
):
    """Run the tilelang kernel, compare it with fla's reference and print timings.

    Dtype arguments are torch dtype names (e.g. ``"bfloat16"``), resolved with
    ``getattr(torch, ...)``. The fla reference receives ``g`` / ``initial_state``
    only when ``use_g`` / ``use_initial_state`` are set, matching what the
    tilelang kernel reads. ``final_state`` and ``V_new`` are compared only when
    ``store_final_state`` / ``save_new_value`` are set. If fla could not be
    imported, only the tilelang kernel is run and timed.

    NOTE: ``block_DK``, ``block_DV``, ``threads`` and ``num_stages`` are accepted
    but not forwarded to the kernel; its ``@autotune`` decorator selects them.
    """
    K, W, U, G, initial_state = prepare_input(
        B,
        S,
        H,
        DK,
        DV,
        chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
    )

    reference_available = fla is not None
    fla_kwargs = dict(
        k=K,
        w=W,
        u=U,
        g=G if use_g else None,
        initial_state=initial_state if use_initial_state else None,
        output_final_state=store_final_state,
        chunk_size=chunk_size,
        save_new_value=save_new_value,
    )

    # fla ref
    if reference_available:
        h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(**fla_kwargs)
    else:
        print("fla not found, skipping the fla reference comparison and timing")

    # tilelang
    kernel = tilelang_chunk_gated_delta_rule_fwd_h(
        B,
        S,
        H,
        DK,
        DV,
        input_dtype,
        output_dtype,
        accum_dtype,
        gate_dtype,
        state_dtype,
        chunk_size,
        use_g,
        use_initial_state,
        store_final_state,
        save_new_value,
    )
    h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
    # (zhengju) If you want to print the generated cuda code, you can uncomment the following line
    # print("CUDA Code:\n", kernel.get_kernel_source())

    fla_time = do_bench(chunk_gated_delta_rule_fwd_h, **fla_kwargs) if reference_available else None
    tilelang_time = do_bench(kernel, K, W, U, G, initial_state)

    # check correctness
    if reference_available:
        _report_similarity(h_ref, h_tilelang, "h")
        if store_final_state:
            _report_similarity(final_state_ref, final_state_tilelang, "final_state")
        if save_new_value:
            _report_similarity(V_new_ref, V_new_tilelang, "V_new")

    print(f"tilelang time: {tilelang_time} ms")
    if reference_available:
        print(f"fla time: {fla_time} ms")


def main():
    """Run the correctness check and benchmark at a long-sequence configuration.

    Uses B=1, S=32768, H=32, DK=DV=128, bf16 inputs/outputs with fp32
    accumulators, gates and final state, chunk_size=64, and gating on without
    an initial state. The block_*/threads/num_stages values are passed to
    run_test, which does not forward them to the autotuned kernel.
    """
    run_test(
        B=1,
        S=32768,
        H=32,
        DK=128,
        DV=128,
        input_dtype="bfloat16",
        output_dtype="bfloat16",
        accum_dtype="float32",
        gate_dtype="float32",
        state_dtype="float32",
        chunk_size=64,
        use_g=True,
        use_initial_state=False,
        store_final_state=True,
        save_new_value=True,
        block_DK=32,
        block_DV=32,
        threads=128,
        num_stages=2,
    )


if __name__ == "__main__":
    main()