example_cumsum.py 4.53 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py

import tilelang
import tilelang.language as T
import sys  # noqa: F401

# To import fla from a local checkout, add the repository path to sys.path (example below).
# This example currently uses the flash-linear-attention repository at commit f03cb3ae.
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# fla is imported optionally; within this file it supplies chunk_local_cumsum_scalar, the
# reference implementation that the TileLang kernel's output is compared against.
try:
    import fla
    # Print where fla was loaded from so it is clear which checkout is in use.
    print(fla.__file__)
    from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
    # NOTE(review): only `fla` is reset to None here; chunk_local_cumsum_scalar stays
    # undefined on this path, so check `fla is None` before calling it.
    print("fla not found, using tilelang implementation")
    fla = None

import torch


@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
    })
def tilelang_chunk_local_cumsum_scalar(
    # task config
    B,
    S,
    H,
    chunk_size=64,
    is_varlen=False,
    head_first=False,
    reverse=False,
    input_dtype="float16",
    output_dtype="float32",
    # kernel config
    block_S=64,
    threads=256,
    use_fragment=False,
):
    """Build a TileLang kernel that runs ``T.cumsum`` over each ``block_S``-long tile of ``G``.

    Every kernel instance handles one (sequence tile, batch, head) triple: it copies the tile
    into a ``(1, block_S)`` buffer, calls ``T.cumsum`` along dim 1, and copies the result to the
    same position of ``G_new``.

    Args:
        B, S, H: Sizes used for the tensor shape, which is ``(B, H, S)`` when ``head_first`` is
            true and ``(B, S, H)`` otherwise.
        chunk_size: Chunk length; asserted to be a power of two and equal to ``block_S``.
        is_varlen: Accepted but never read in this function.
        head_first: Selects the tensor layout (see above) and the matching indexing.
        reverse: Forwarded to ``T.cumsum`` as its ``reverse`` argument.
        input_dtype: dtype name of ``G``.
        output_dtype: dtype name of ``G_new`` and of the on-chip buffers.
        block_S: Number of sequence positions handled per kernel instance.
        threads: Passed to ``T.Kernel`` as ``threads``.
        use_fragment: When true, the tile is copied into a second buffer obtained from
            ``T.alloc_fragment`` and the cumsum runs there instead of on ``G_shared``.

    Returns:
        The ``T.prim_func`` ``kernel``. NOTE(review): presumably ``tilelang.jit`` compiles it and
        ``out_idx=[-1]`` makes the last parameter (``G_new``) a returned output -- the caller in
        this file invokes the result as ``kernel(G)``; confirm against the TileLang docs.

    Raises:
        AssertionError: If ``chunk_size`` is not a power of two or differs from ``block_S``.
    """
    G_shape = (B, H, S) if head_first else (B, S, H)
    # 2**(bit_length - 1) is the largest power of two <= chunk_size, so equality holds only
    # for powers of two.
    assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    assert chunk_size == block_S, "chunk_size must be equal to block_S"

    @T.prim_func
    def kernel(
            G: T.Tensor(G_shape, dtype=input_dtype),
            G_new: T.Tensor(G_shape, dtype=output_dtype),
    ):
        # Launch grid: ceildiv(S, block_S) sequence tiles by B * H (batch, head) pairs.
        with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
            # Split the flattened (batch, head) index back into its two components.
            bb, bh = bbh // H, bbh % H
            # The staging buffer is declared with output_dtype while G uses input_dtype.
            # NOTE(review): presumably T.copy converts between the two -- confirm.
            G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
            if head_first:
                T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
            else:
                T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
            if use_fragment:
                # NOTE(review): alloc_fragment is called with scope="shared"; confirm this is
                # intended rather than the allocator's default scope.
                G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
                T.copy(G_shared, G_fragment)
                T.cumsum(G_fragment, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
                else:
                    T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
            else:
                # Cumsum is applied directly to the staging buffer, then written out.
                T.cumsum(G_shared, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
                else:
                    T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])

    return kernel


def prepare_cumsum_input(
    B,
    S,
    H,
    dtype,
    device="cuda",
):
    """Create a random input tensor for the cumsum test.

    Args:
        B, S, H: Sizes of the three tensor dimensions, used in the order given.
        dtype: ``torch.dtype`` of the returned tensor.
        device: Device to allocate on. Defaults to ``"cuda"``, which matches the previous
            behaviour of always returning a CUDA tensor.

    Returns:
        A ``(B, S, H)`` tensor of standard-normal samples on ``device``.
    """
    # Sample directly on the target device instead of sampling on the host and copying.
    # With the default device the values now come from the CUDA generator.
    return torch.randn(B, S, H, dtype=dtype, device=device)


def prepare_cumsum_output(
    B,
    S,
    H,
    dtype,
    device="cuda",
):
    """Allocate an uninitialised output tensor for the cumsum test.

    Args:
        B, S, H: Sizes of the three tensor dimensions, used in the order given.
        dtype: ``torch.dtype`` of the returned tensor.
        device: Device to allocate on. Defaults to ``"cuda"``, which matches the previous
            behaviour of always returning a CUDA tensor.

    Returns:
        A ``(B, S, H)`` tensor on ``device`` whose contents are uninitialised.
    """
    # Allocate on the target device directly; no host buffer or copy is needed.
    return torch.empty(B, S, H, dtype=dtype, device=device)


def _torch_chunk_local_cumsum(g, chunk_size, reverse, head_first, output_dtype):
    """Pure-PyTorch chunk-local (inclusive) cumulative sum, used when fla is unavailable."""
    seq_dim = 2 if head_first else 1
    pieces = []
    for chunk in g.to(output_dtype).split(chunk_size, dim=seq_dim):
        if reverse:
            # A reverse cumsum is a forward cumsum over the flipped chunk, flipped back.
            chunk = chunk.flip(dims=(seq_dim,)).cumsum(dim=seq_dim).flip(dims=(seq_dim,))
        else:
            chunk = chunk.cumsum(dim=seq_dim)
        pieces.append(chunk)
    return torch.cat(pieces, dim=seq_dim)


def run_test(
    B,
    S,
    H,
    chunk_size,
    reverse,
    head_first,
    input_dtype,
    output_dtype,
    threads,
    use_fragment,
):
    """Compare the TileLang chunk-local cumsum kernel against a reference and print the result.

    Args:
        B, S, H: Batch, sequence and head sizes.
        chunk_size: Chunk length; also used as the kernel's ``block_S``.
        reverse: Whether the cumulative sum runs from the end of each chunk.
        head_first: Use the ``(B, H, S)`` layout instead of ``(B, S, H)``.
        input_dtype: Name of the torch dtype of the input (e.g. ``"float16"``).
        output_dtype: Name of the torch dtype of the output (e.g. ``"float32"``).
        threads: Thread count forwarded to the kernel builder.
        use_fragment: Forwarded to the kernel builder.

    The reference is fla's ``chunk_local_cumsum_scalar`` when fla was imported, otherwise a
    pure-PyTorch equivalent. A mismatch is reported on stdout; it is not re-raised.
    """
    in_dtype = getattr(torch, input_dtype)
    out_dtype = getattr(torch, output_dtype)

    # prepare_cumsum_input lays its three sizes out in the order given, so swap the sequence
    # and head sizes to obtain the (B, H, S) layout the kernel is compiled for when
    # head_first is set.
    if head_first:
        G = prepare_cumsum_input(B, H, S, in_dtype)
    else:
        G = prepare_cumsum_input(B, S, H, in_dtype)

    # reference cumsum
    if fla is not None:
        G_new_ref = chunk_local_cumsum_scalar(
            g=G,
            chunk_size=chunk_size,
            reverse=reverse,
            head_first=head_first,
            output_dtype=out_dtype)
    else:
        # The optional fla import failed, so chunk_local_cumsum_scalar is undefined.
        G_new_ref = _torch_chunk_local_cumsum(G, chunk_size, reverse, head_first, out_dtype)

    # tilelang cumsum
    block_S = chunk_size
    kernel = tilelang_chunk_local_cumsum_scalar(
        B=B,
        S=S,
        H=H,
        chunk_size=chunk_size,
        reverse=reverse,
        head_first=head_first,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        block_S=block_S,
        threads=threads,
        use_fragment=use_fragment,
    )
    # Bracket only the kernel launch so an attached CUDA profiler captures just this call.
    torch.cuda.profiler.start()
    G_new_tilelang = kernel(G)
    torch.cuda.profiler.stop()

    try:
        torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
    except AssertionError as e:
        # A mismatch is reported rather than raised; other error types still propagate.
        print("tilelang cumsum failed ✗")
        print(e)
        print("G:")
        print(G.reshape(-1))
        print("G_new_tilelang:")
        print(G_new_tilelang.reshape(-1))
        print("G_new_ref:")
        print(G_new_ref.reshape(-1))
    else:
        print("tilelang cumsum passed √")


def main():
    """Run the cumsum comparison once with the example's fixed configuration."""
    config = {
        "B": 1,
        "S": 32768,
        "H": 32,
        "chunk_size": 64,
        "reverse": True,
        "head_first": False,
        "input_dtype": "float32",
        "output_dtype": "float32",
        "threads": 256,
        "use_fragment": False,
    }
    run_test(**config)


if __name__ == "__main__":
    main()