from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
import pytest


def reshape_test(N, M, dtype):
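    """Reshape a 1-D global tensor of length N into (N // M, M) and copy it to B."""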
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M])
            T.copy(A_reshaped, B)

    return main


def run_reshape(N, M, dtype):
    program = reshape_test(N, M, dtype)
    # TODO(lei): reshape cannot yet propagate shared-memory layout
    # transforms, so disable TMA lowering and warp specialization.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
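        # out_idx=-1 marks the last parameter (B) as the kernel output.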
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem():
    # Global -> global reshape, float32 and float16 variants.
    run_reshape(1024, 32, "float32")
    run_reshape(2048, 64, "float16")


def reshape_test_smem_1d_2_2d(N, M, dtype):
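    """Stage A into 1-D shared memory, reshape the shared buffer to 2-D, and copy it to B."""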
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N,), dtype)
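            # Element-wise copy into shared memory before reshaping the shared view.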
            for i in T.Parallel(N):
                A_shared[i] = A[i]

            A_smem_reshaped = T.reshape(A_shared, [N // M, M])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_1d_2_2d(N, M, dtype):
    program = reshape_test_smem_1d_2_2d(N, M, dtype)
    # TODO(lei): reshape cannot yet propagate shared-memory layout
    # transforms, so disable TMA lowering and warp specialization.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem_1d_2_2d():
    run_reshape_smem_1d_2_2d(1024, 32, "float32")
    run_reshape_smem_1d_2_2d(2048, 64, "float16")


def reshape_test_smem_2d_2_1d(N, M, dtype):
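    """Stage A into 2-D shared memory, reshape the shared buffer to 1-D, and copy it to B."""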
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N // M, M), dtype)
            for i, j in T.Parallel(N // M, M):
                A_shared[i, j] = A[i, j]

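            # Flatten the 2-D shared buffer into a 1-D view for the copy out.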
            A_smem_reshaped = T.reshape(A_shared, [N])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_2d_2_1d(N, M, dtype):
    program = reshape_test_smem_2d_2_1d(N, M, dtype)
    # TODO(lei): reshape cannot yet propagate shared-memory layout
    # transforms, so disable TMA lowering and warp specialization.
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem_2d_2_1d():
    run_reshape_smem_2d_2_1d(1024, 32, "float32")
    run_reshape_smem_2d_2_1d(2048, 64, "float16")


def reshape_fragment_test(N, M, dtype):
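    """Reshape a 2-D register fragment to 1-D and write it back through shared memory."""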
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
            A_local = T.alloc_fragment((N // M, M), dtype)
            B_shared = T.alloc_shared((N,), dtype, scope="shared")

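            # Stage global -> shared -> fragment, reshape the fragment view,
            # then write the result back through shared memory.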
            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            A_local_reshape = T.reshape(A_local, [N])
            T.copy(A_local_reshape, B_shared)
            T.copy(B_shared, B)

    return main


def run_reshape_fragment(N, M, dtype):
    program = reshape_fragment_test(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_fragment():
    run_reshape_fragment(1024, 32, "float32")
    run_reshape_fragment(2048, 64, "float16")


def reshape_layout_transform_shared(N, M, dtype):
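    """Reshape a shared buffer that carries a swizzled MMA layout annotation."""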
    import tilelang.language as T
    from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout

    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")

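            # Request a swizzled MMA layout for the shared buffer; the
            # reshape below must still read it back correctly.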
            T.annotate_layout({
                A_shared: make_mma_swizzle_layout(A_shared),
            })
            T.copy(A, A_shared)
            A_shared_reshape = T.reshape(A_shared, [N])
            T.copy(A_shared_reshape, B)

    return main


def run_reshape_layout_transform_shared(N, M, dtype):
    program = reshape_layout_transform_shared(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_layout_transform_shared():
    run_reshape_layout_transform_shared(1024, 32, "float32")
    run_reshape_layout_transform_shared(2048, 64, "float16")


def reduce_after_reshape_test(N, M, dtype):
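    """Reshape a 1-D fragment to (N // M, M), then reduce-max along the last dimension."""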
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N,), dtype, scope="shared")
            A_local = T.alloc_fragment((N,), dtype)
            B_local = T.alloc_fragment((N // M,), dtype)

            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
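            # View the 1-D fragment as (N // M, M) and reduce along the last dim.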
            A_local_reshape = T.reshape(A_local, [N // M, M])
            T.reduce_max(A_local_reshape, B_local, dim=1)
            T.copy(B_local, B)

    return main


def run_reduce_after_reshape(N, M, dtype):
    program = reduce_after_reshape_test(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return torch.max(A.reshape(N // M, M), dim=1).values

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_after_reshape():
    run_reduce_after_reshape(1024, 32, "float32")
    run_reduce_after_reshape(2048, 64, "float16")


def reshape_shape_mismatch_test(N, M, dtype):
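    """Reshape with a mismatched element count; tracing should raise an AssertionError."""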
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
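            # (N // M) * (M + 1) != N, so T.reshape should fail its shape check.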
            A_reshaped = T.reshape(A, [N // M, M + 1])
            T.copy(A_reshaped, B)

    return main


def test_reshape_shape_mismatch():
    with pytest.raises(AssertionError):
        reshape_shape_mismatch_test(1024, 32, "float32")


if __name__ == "__main__":
    tilelang.testing.main()