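"""Tests for tilelang reduction ops (reduce_sum/max/min/abssum/absmax) over
register fragments and shared-memory buffers."""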
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T

tilelang.testing.set_random_seed()


def _make_shared_reduce(M, N, dtype, reduce_cb):
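    """Build a kernel that reduces A along dim 1 through shared-memory buffers.

    reduce_cb receives the language module and the source/destination buffers,
    so each shared-memory variant below only has to supply the reduce op.
    """
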
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)

            T.copy(A, A_shared)
            reduce_cb(T, A_shared, B_shared)
            T.copy(B_shared, B)

    return main


def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
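    """Compile `program` with tl.compile and check it against a torch reference."""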
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()
    profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)


def reduce_max_test(M, N, dtype=T.float16):
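    """Row-wise max over register fragments."""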

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            T.copy(A, A_local)
            T.reduce_max(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main


def reduce_sum_test(M, N, dtype=T.float32):
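    """Row-wise sum over register fragments."""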

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            T.copy(A, A_local)
            T.reduce_sum(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main


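# Shared-memory ("ss") variants: the same row reductions, staged entirely
# through shared memory via _make_shared_reduce.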
def reduce_sum_ss(M, N, dtype=T.float32):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))


def reduce_max_ss(M, N, dtype=T.float32):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))


def reduce_min_ss(M, N, dtype=T.float32):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1))


def reduce_abssum_ss(M, N, dtype=T.float32):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1))


def reduce_absmax_ss(M, N, dtype=T.float32):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1))


def run_reduce_sum(M, N, dtype=T.float32, mode="rr"):
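    """Validate the row-sum kernel against torch, in register ("rr") or
    shared-memory ("ss") mode."""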
    if mode == "rr":
        program = reduce_sum_test(M, N, dtype)
    elif mode == "ss":
        program = reduce_sum_ss(M, N, dtype)
    else:
        raise ValueError(f"run_reduce_sum only supports mode 'rr' or 'ss', got {mode!r}")
    _run_program(program, lambda A: A.sum(dim=1))


def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32):
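    """Build a shared-memory reduction kernel and validate it against ref_program."""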
    program = program_builder(M, N, dtype)
    _run_program(program, ref_program)


def run_reduce_max(M, N, dtype=T.float16):
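    """Validate the register row-max kernel against torch's A.max(dim=1)."""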
    program = reduce_max_test(M, N, dtype)
    _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)


def test_reduce_sum_shared():
    run_reduce_sum(64, 64, mode="ss")


def test_reduce_max():
    run_reduce_max(256, 256, T.float16)
    run_reduce_max(512, 128, T.float16)
    run_reduce_max(256, 256, T.float32)


def test_reduce_max_shared():
    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, T.float32)


def test_reduce_min_shared():
    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, T.float32)


def test_reduce_abssum_shared():
    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, T.float32)


def test_reduce_absmax_shared():
    run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, T.float32)


def reduce_sum_test_clear(M, N, dtype=T.float32):
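    """Row-wise sum with clear=False: the reduction accumulates on top of
    B_local's existing contents, so pre-filling it with 1 offsets every row
    sum by 1."""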

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_sum_clear(M, N, dtype=T.float32):
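    """Validate the clear=False sum kernel; the reference is A.sum(dim=1) + 1
    because B_local is pre-filled with 1."""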
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, T.float32)
    run_reduce_sum_clear(512, 128, T.float32)
    run_reduce_sum_clear(128, 512, T.float32)


def reduce_max_test_clear(M, N, dtype=T.float16):
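    """Row-wise max with clear=False: B_local's contents seed the running max,
    so it is pre-filled with -inf, the identity element for max."""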

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_max_clear(M, N, dtype=T.float16):
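    """Validate the clear=False max kernel; since the -inf pre-fill is the max
    identity, the reference is a plain row max."""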
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.max(dim=1).values

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, T.float16)


if __name__ == "__main__":
    tilelang.testing.main()