# [repository-viewer notice, unrelated to this example] "test/verify/test_softmax3.cpp" did not exist on "2466dd6f1073fbc5925d98d9aef344d791e15f43"
# example_chunk_scaled_dot_kkt.py
# Reference: fla/ops/common/chunk_scaled_dot_kkt.py

import tilelang
import tilelang.language as T
import sys  # noqa: F401

# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# Optional dependency: the fla reference implementation. When it cannot be
# imported, both `fla` and `chunk_scaled_dot_kkt_fwd` are bound to None so the
# rest of the module can test for its absence instead of hitting a NameError.
try:
    import fla

    print(fla.__file__)
    from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
    print("fla not found, using tilelang implementation")
    fla = None
    chunk_scaled_dot_kkt_fwd = None

import torch

torch.set_printoptions(profile="full")
torch.random.manual_seed(0)


def prepare_input(B, S, H, DK, input_dtype, output_dtype, accum_dtype):
    """Draw random K, Beta and G test tensors and move them to the GPU.

    Args:
        B, S, H, DK: sizes used for the tensor shapes below.
        input_dtype: torch dtype of K and Beta.
        output_dtype: accepted for signature compatibility; not used here.
        accum_dtype: torch dtype of G.

    Returns:
        Tuple ``(K, Beta, G)`` with shapes ``(B, S, H, DK)``, ``(B, S, H)``
        and ``(B, S, H)``, each produced by ``torch.randn`` and then ``.cuda()``.
    """
    per_token_shape = (B, S, H)
    # Sampling order (K, then Beta, then G) is kept so seeded runs reproduce the same values.
    keys = torch.randn(*per_token_shape, DK, dtype=input_dtype).cuda()
    betas = torch.randn(*per_token_shape, dtype=input_dtype).cuda()
    gates = torch.randn(*per_token_shape, dtype=accum_dtype).cuda()
    return keys, betas, gates


def prepare_output(B, S, H, chunk_size, dtype):
    """Allocate an uninitialised ``(B, S, H, chunk_size)`` tensor of ``dtype`` and move it to the GPU.

    The contents come from ``torch.empty`` and are therefore unspecified.
    """
    out_shape = (B, S, H, chunk_size)
    host_buffer = torch.empty(out_shape, dtype=dtype)
    return host_buffer.cuda()


@tilelang.jit(out_idx=[-1])
def tilelang_chunk_scaled_dot_kkt_fwd(
    # task config
    B,
    S,
    H,
    DK,
    chunk_size=64,
    input_dtype="bfloat16",
    output_dtype="bfloat16",
    accum_dtype="float32",
    use_g=True,
    # kernel config
    block_S=64,
    block_DK=64,
    threads=256,
    num_stages=0,
):
    """Build the tilelang kernel computing the chunk-local, Beta-scaled ``K @ K^T`` product.

    For every chunk of ``block_S`` positions along the S axis the kernel
    accumulates ``(Beta * K) @ K^T`` over the DK axis in tiles of ``block_DK``
    and writes the ``block_S x block_S`` result into ``A``. Afterwards:

    * ``use_g=True``: entry ``(i, j)`` is multiplied by ``exp(G[i] - G[j])``
      when ``G[i] - G[j] <= 0`` and ``i > j``; every other entry is set to 0.
    * ``use_g=False``: entries with ``i <= j`` are set to 0.

    Args:
        B, S, H, DK: sizes defining the tensor shapes ``K: (B, S, H, DK)``,
            ``Beta: (B, S, H)``, ``G: (B, S, H)`` and ``A: (B, S, H, chunk_size)``.
        chunk_size: chunk length along S; must equal ``block_S``.
        input_dtype: dtype name of ``K`` and ``Beta``.
        output_dtype: dtype name of ``A``.
        accum_dtype: dtype name of ``G`` and of the accumulator fragment.
        use_g: whether to apply the ``G``-based gating described above.
        block_S: number of S positions handled per kernel block.
        block_DK: tile width along DK per pipelined iteration.
        threads: thread count passed to ``T.Kernel``.
        num_stages: stage count passed to ``T.Pipelined``.

    Returns:
        The ``T.prim_func`` kernel taking ``(K, Beta, G, A)``. NOTE(review):
        ``run_test`` calls the compiled kernel with only ``(K, Beta, G)`` and
        receives ``A`` back, which is presumably what ``out_idx=[-1]`` arranges
        - confirm against the tilelang.jit documentation.

    Raises:
        AssertionError: if ``chunk_size != block_S``.
    """
    K_shape = (B, S, H, DK)
    Beta_shape = (B, S, H)
    G_shape = (B, S, H)
    assert chunk_size == block_S, "chunk_size must be equal to block_S"
    BS = chunk_size
    output_shape = (B, S, H, BS)

    @T.prim_func
    def kernel(
        K: T.Tensor(K_shape, dtype=input_dtype),
        Beta: T.Tensor(Beta_shape, dtype=input_dtype),
        G: T.Tensor(G_shape, dtype=accum_dtype),
        A: T.Tensor(output_shape, dtype=output_dtype),
    ):
        # Launch extents: ceildiv(S, block_S) chunks along S times B * H (batch, head) pairs.
        with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
            # Split the fused index into batch index `bb` and head index `bh`.
            bb, bh = bbh // H, bbh % H
            # !! Pay attention to the scope of the shared memory: may cause misaligned address when shape is one dimension or the buffer is too small
            Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
            K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
            A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
            Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
            A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)

            # Tensor used for gated:
            G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
            G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)

            T.annotate_layout(
                {
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
                }
            )

            # The accumulator starts at zero; the gemm in the DK loop adds into it.
            T.fill(A_fragment, 0)
            T.disable_warp_group_reg_alloc()

            # Load this chunk's Beta values once; they are reused for every DK tile.
            for i_s in T.Parallel(block_S):
                Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]

            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
                T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
                # Scale each row of the K tile by that row's Beta.
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
                # (Beta * K) times K^T for this DK tile, into A_fragment.
                T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)

            if use_g:
                for i_s in T.Parallel(block_S):
                    G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
                # Pairwise differences G[i] - G[j] within the chunk.
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
                # Scale by exp(G[i] - G[j]) where the difference is non-positive and i > j; zero elsewhere.
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
                        with T.Then():
                            A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
                        with T.Else():
                            A_fragment[i_s1, i_s2] = 0
            else:
                # Without gating only the i <= j entries are zeroed.
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    with T.If(i_s1 <= i_s2):  # noqa: SIM117
                        with T.Then():
                            A_fragment[i_s1, i_s2] = 0

            # Write the result out via A_shared to this chunk's slice of A.
            T.copy(A_fragment, A_shared)
            T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])

    return kernel


def run_test(
    B,
    S,
    H,
    DK,
    chunk_size,
    input_dtype,
    output_dtype,
    accum_dtype,
    use_g,
    block_DK,
    threads,
    num_stages,
):
    """Run the tilelang kernel on random inputs and compare it with the fla reference.

    The dtype arguments are dtype *names* (e.g. ``"bfloat16"``); they are passed
    as strings to the tilelang kernel builder and resolved with
    ``getattr(torch, name)`` for the torch side.

    The tilelang kernel is always built and executed. When fla could not be
    imported (module-level ``fla`` is ``None``) the comparison is skipped with a
    message instead of failing on the missing reference function.

    A mismatch is reported by printing the error and the kernel source; it is
    not re-raised.
    """
    torch_output_dtype = getattr(torch, output_dtype)
    K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), torch_output_dtype, getattr(torch, accum_dtype))

    # tilelang
    block_S = chunk_size
    kernel = tilelang_chunk_scaled_dot_kkt_fwd(
        B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
    )
    A_tilelang = kernel(K, Beta, G)

    if fla is None:
        # The optional import failed, so there is no reference to compare against.
        print("fla not found, skipping comparison with the reference implementation")
        return

    # reference (the gate tensor is only passed when gating is enabled)
    A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G if use_g else None, chunk_size=chunk_size, output_dtype=torch_output_dtype)

    try:
        torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2)
    except Exception as e:
        # Deliberately best-effort: report the mismatch and dump the kernel source for debugging.
        print("tilelang chunk scaled dot kkt fwd failed ✗")
        print(e)
        print("reference cuda kernel:")
        print(kernel.get_kernel_source())
    else:
        print("tilelang chunk scaled dot kkt fwd passed √")


def main():
    """Run the example once with a fixed bfloat16 configuration and gating enabled."""
    run_test(
        B=1,
        S=32768,
        H=32,
        DK=128,
        chunk_size=64,
        input_dtype="bfloat16",
        output_dtype="bfloat16",
        accum_dtype="float32",
        use_g=True,
        block_DK=64,
        threads=128,
        num_stages=2,
    )


if __name__ == "__main__":
    main()