test_tilelang_language_reduce.py 6.13 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

# Called once at import time, presumably to seed the RNGs behind the random
# test inputs so runs are reproducible — NOTE(review): confirm which RNGs
# tilelang.testing.set_random_seed actually seeds.
tilelang.testing.set_random_seed()


def _make_shared_reduce(M, N, dtype, reduce_cb):
    """Build a single-block kernel that reduces an (M, N) tensor through shared memory.

    The input ``A`` is copied into a shared (M, N) buffer, ``reduce_cb`` emits
    the reduction into a shared (M,) buffer, and that buffer is copied to ``B``.

    Args:
        M, N: static extents of the input tensor ``A``; ``B`` has extent ``M``.
        dtype: element dtype string used for both ``A`` and ``B``.
        reduce_cb: callable ``(T, src, dst)`` invoked while the kernel is being
            parsed, where ``T`` is ``tilelang.language``, ``src`` the shared
            (M, N) buffer and ``dst`` the shared (M,) buffer.

    Returns:
        The ``T.prim_func`` kernel (last parameter ``B`` is the output).
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        # A single block; thread count is left to the T.Kernel default.
        with T.Kernel(1) as _:
            src_smem = T.alloc_shared((M, N), dtype)
            dst_smem = T.alloc_shared((M,), dtype)

            # global -> shared, reduce inside shared memory, shared -> global
            T.copy(A, src_smem)
            reduce_cb(T, src_smem, dst_smem)
            T.copy(dst_smem, B)

    return main


def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
    """Compile ``program`` and assert its output matches ``ref_program``.

    ``out_idx=-1`` designates the kernel's last parameter as its output. The
    comparison itself (including, presumably, input generation) is delegated
    to the compiled kernel's profiler.

    Args:
        program: a ``T.prim_func`` whose last parameter is the output.
        ref_program: torch reference taking the kernel inputs.
        atol, rtol: tolerances forwarded to ``assert_allclose``.
    """
    profiler = tl.compile(program, out_idx=-1).get_profiler()
    profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)


def _make_fragment_reduce(M, N, dtype, reduce_cb):
    """Build a single-block kernel that reduces an (M, N) tensor held in fragments.

    Fragment (``T.alloc_fragment``) counterpart of ``_make_shared_reduce``:
    ``A`` is copied into an (M, N) fragment, ``reduce_cb`` emits the reduction
    into an (M,) fragment, and that fragment is copied to ``B``.

    Args:
        M, N: static extents of the input tensor ``A``; ``B`` has extent ``M``.
        dtype: element dtype string used for both ``A`` and ``B``.
        reduce_cb: callable ``(T, src, dst)`` invoked while the kernel is being
            parsed, where ``T`` is ``tilelang.language``.

    Returns:
        The ``T.prim_func`` kernel (last parameter ``B`` is the output).
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            T.copy(A, A_local)
            reduce_cb(T, A_local, B_local)
            T.copy(B_local, B)

    return main


def reduce_max_test(M, N, dtype="float16"):
    """Fragment kernel computing ``T.reduce_max`` over dim 1 of an (M, N) tensor."""
    return _make_fragment_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))


def reduce_sum_test(M, N, dtype="float32"):
    """Fragment kernel computing ``T.reduce_sum`` over dim 1 of an (M, N) tensor."""
    return _make_fragment_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))


def _reduce_dim1_with(op_name):
    """Return a ``reduce_cb`` that emits ``T.<op_name>(src, dst, dim=1)``."""

    def _emit(T, src, dst):
        return getattr(T, op_name)(src, dst, dim=1)

    return _emit


def reduce_sum_ss(M, N, dtype="float32"):
    """Shared-memory kernel: ``T.reduce_sum`` over dim 1."""
    return _make_shared_reduce(M, N, dtype, _reduce_dim1_with("reduce_sum"))


def reduce_max_ss(M, N, dtype="float32"):
    """Shared-memory kernel: ``T.reduce_max`` over dim 1."""
    return _make_shared_reduce(M, N, dtype, _reduce_dim1_with("reduce_max"))


def reduce_min_ss(M, N, dtype="float32"):
    """Shared-memory kernel: ``T.reduce_min`` over dim 1."""
    return _make_shared_reduce(M, N, dtype, _reduce_dim1_with("reduce_min"))


def reduce_abssum_ss(M, N, dtype="float32"):
    """Shared-memory kernel: ``T.reduce_abssum`` over dim 1."""
    return _make_shared_reduce(M, N, dtype, _reduce_dim1_with("reduce_abssum"))


def reduce_absmax_ss(M, N, dtype="float32"):
    """Shared-memory kernel: ``T.reduce_absmax`` over dim 1."""
    return _make_shared_reduce(M, N, dtype, _reduce_dim1_with("reduce_absmax"))


def run_reduce_sum(M, N, dtype="float32", mode="rr"):
    """Check ``T.reduce_sum`` over dim 1 against ``A.sum(dim=1)``.

    Args:
        mode: ``"rr"`` uses fragment operands (``reduce_sum_test``);
            ``"ss"`` uses shared-memory operands (``reduce_sum_ss``).

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode not in ("rr", "ss"):
        raise NotImplementedError("run_reduce_sum only supports rr and ss")
    build = reduce_sum_test if mode == "rr" else reduce_sum_ss
    _run_program(build(M, N, dtype), lambda A: A.sum(dim=1))


def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"):
    """Build a kernel via ``program_builder(M, N, dtype)`` and check it against ``ref_program``."""
    _run_program(program_builder(M, N, dtype), ref_program)


def run_reduce_max(M, N, dtype="float16"):
    """Check the fragment ``T.reduce_max`` kernel against ``A.max(dim=1).values``."""

    def torch_max_dim1(A):
        return A.max(dim=1).values

    kernel = reduce_max_test(M, N, dtype)
    _run_program(kernel, torch_max_dim1, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    """Fragment reduce_sum on square and rectangular float32 inputs."""
    for m, n in ((256, 256), (512, 128), (128, 512)):
        run_reduce_sum(m, n)


def test_reduce_sum_shared():
    """Shared-memory reduce_sum on a square and a non-square tile."""
    for m, n in ((64, 64), (32, 96)):
        run_reduce_sum(m, n, mode="ss")


def test_reduce_max():
    """Fragment reduce_max across shapes and dtypes."""
    cases = (
        (256, 256, "float16"),
        (512, 128, "float16"),
        (256, 256, "float32"),
    )
    for m, n, dtype in cases:
        run_reduce_max(m, n, dtype)


def test_reduce_max_shared():
    """Shared-memory reduce_max on a square and a non-square float32 tile."""

    def ref(A):
        return A.max(dim=1).values

    for m, n in ((64, 64), (96, 48)):
        run_shared_reduce(reduce_max_ss, ref, m, n, "float32")


def test_reduce_min_shared():
    """Shared-memory reduce_min on a 64x64 float32 tile."""
    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, M=64, N=64, dtype="float32")


def test_reduce_abssum_shared():
    """Shared-memory reduce_abssum on a 64x64 float32 tile."""
    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), M=64, N=64, dtype="float32")


def test_reduce_absmax_shared():
    """Shared-memory reduce_absmax on a 64x64 float32 tile."""
    run_shared_reduce(
        reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, M=64, N=64, dtype="float32"
    )


def reduce_sum_test_clear(M, N, dtype="float32"):
    """Fragment reduce_sum kernel that accumulates into a pre-filled output.

    The output fragment is filled with 1 and ``T.reduce_sum`` is called with
    ``clear=False``, so the expected result is ``A.sum(dim=1) + 1``.
    Launched as one block of 32 threads.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            tile = T.alloc_fragment((M, N), dtype)
            acc = T.alloc_fragment((M,), dtype)

            T.copy(A, tile)
            # Seed the accumulator; clear=False keeps this value in the sum.
            T.fill(acc, 1)
            T.reduce_sum(tile, acc, dim=1, clear=False)
            T.copy(acc, B)

    return main


def run_reduce_sum_clear(M, N, dtype="float32", atol=1e-2, rtol=1e-2):
    """Check ``reduce_sum(..., clear=False)`` against ``A.sum(dim=1) + 1``.

    The kernel from ``reduce_sum_test_clear`` pre-fills its output with 1
    before accumulating, hence the ``+ 1`` in the reference. The input is
    generated explicitly with ``torch.randn``. Because the kernel is compiled
    with ``out_idx=-1``, it is called with ``A`` only and returns the output.

    Args:
        M, N: input extents.
        dtype: dtype name; must be an attribute of ``torch`` (e.g. "float32").
        atol, rtol: tolerances forwarded to ``torch.testing.assert_close``.
    """
    import torch

    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    # Create the input on the GPU directly instead of on the host plus a copy.
    A = torch.randn((M, N), dtype=getattr(torch, dtype), device="cuda")
    ref_out = A.sum(dim=1) + 1
    tl_out = jit_kernel(A)
    torch.testing.assert_close(tl_out, ref_out, atol=atol, rtol=rtol)


def test_reduce_sum_clear():
    """reduce_sum with clear=False on square and rectangular float32 inputs."""
    for m, n in ((256, 256), (512, 128), (128, 512)):
        run_reduce_sum_clear(m, n, "float32")


def reduce_max_test_clear(M, N, dtype="float16"):
    """Fragment reduce_max kernel that accumulates into a pre-filled output.

    The output fragment is filled with ``-inf`` and ``T.reduce_max`` is called
    with ``clear=False``, so the expected result equals ``A.max(dim=1)``.
    Launched as one block of 32 threads.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            tile = T.alloc_fragment((M, N), dtype)
            acc = T.alloc_fragment((M,), dtype)

            T.copy(A, tile)
            # -inf is the identity for max, so clear=False must not change the result.
            T.fill(acc, -T.infinity(dtype))
            T.reduce_max(tile, acc, dim=1, clear=False)
            T.copy(acc, B)

    return main


def run_reduce_max_clear(M, N, dtype="float16", atol=1e-2, rtol=1e-2):
    """Check ``reduce_max(..., clear=False)`` against ``A.max(dim=1).values``.

    The kernel from ``reduce_max_test_clear`` pre-fills its output with
    ``-inf`` before accumulating, so the reference is the plain maximum. The
    input is generated explicitly with ``torch.randn``. Because the kernel is
    compiled with ``out_idx=-1``, it is called with ``A`` only and returns the
    output.

    Args:
        M, N: input extents.
        dtype: dtype name; must be an attribute of ``torch`` (e.g. "float16").
        atol, rtol: tolerances forwarded to ``torch.testing.assert_close``.
    """
    import torch

    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    # Create the input on the GPU directly instead of on the host plus a copy.
    A = torch.randn((M, N), dtype=getattr(torch, dtype), device="cuda")
    ref_out = A.max(dim=1).values
    tl_out = jit_kernel(A)
    torch.testing.assert_close(tl_out, ref_out, atol=atol, rtol=rtol)


def test_reduce_max_clear():
    """reduce_max with clear=False on a 256x256 float16 input."""
    run_reduce_max_clear(M=256, N=256, dtype="float16")


# When executed as a script, hand off to tilelang.testing.main()
# (presumably collects and runs this module's test_* functions — confirm).
if __name__ == "__main__":
    tilelang.testing.main()