import tilelang.testing
import tilelang as tl
from tilelang import language as T
import torch
import pytest


def reshape_test(N, M, dtype):
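    # Kernel: view the 1-D global tensor A of length N as (N // M, M) via
    # T.reshape and copy the reshaped view directly into B.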
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M])
            T.copy(A_reshaped, B)

    return main


def run_reshape(N, M, dtype):
    program = reshape_test(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem():
    # Exercise the plain global-memory reshape kernel with float32 and float16 inputs.
    run_reshape(1024, 32, T.float32)
    run_reshape(2048, 64, T.float16)


def reshape_test_smem_1d_2_2d(N, M, dtype):
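    # Kernel: stage A into a 1-D shared-memory buffer, reshape that buffer to
    # (N // M, M), then copy the 2-D view into B.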
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N,), dtype)
            for i in T.Parallel(N):
                A_shared[i] = A[i]

            A_smem_reshaped = T.reshape(A_shared, [N // M, M])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_1d_2_2d(N, M, dtype):
    program = reshape_test_smem_1d_2_2d(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem_1d_2_2d():
    run_reshape_smem_1d_2_2d(1024, 32, T.float32)
    run_reshape_smem_1d_2_2d(2048, 64, T.float16)


def reshape_test_smem_2d_2_1d(N, M, dtype):
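    # Kernel: stage A into a 2-D shared-memory buffer, reshape it back to a
    # flat length-N view, then copy that view into B.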
    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N // M, M), dtype)
            for i, j in T.Parallel(N // M, M):
                A_shared[i, j] = A[i, j]

            A_smem_reshaped = T.reshape(A_shared, [N])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_2d_2_1d(N, M, dtype):
    program = reshape_test_smem_2d_2_1d(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_smem_2d_2_1d():
    run_reshape_smem_2d_2_1d(1024, 32, T.float32)
    run_reshape_smem_2d_2_1d(2048, 64, T.float16)


def reshape_fragment_test(N, M, dtype):
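    # Kernel: load A through shared memory into a register fragment, reshape
    # the fragment to 1-D, and write it back out through shared memory into B.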
    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
            A_local = T.alloc_fragment((N // M, M), dtype)
            B_shared = T.alloc_shared((N,), dtype, scope="shared")

            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            A_local_reshape = T.reshape(A_local, [N])
            T.copy(A_local_reshape, B_shared)
            T.copy(B_shared, B)

    return main


def run_reshape_fragment(N, M, dtype):
    program = reshape_fragment_test(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_fragment():
    run_reshape_fragment(1024, 32, T.float32)
    run_reshape_fragment(2048, 64, T.float16)


def reshape_layout_transform_shared(N, M, dtype):
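    # Kernel: annotate A_shared with a swizzled MMA layout, then reshape the
    # swizzled shared buffer to 1-D and copy it into B.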
    from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout

    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")

            T.annotate_layout(
                {
                    A_shared: make_mma_swizzle_layout(A_shared),
                }
            )
            T.copy(A, A_shared)
            A_shared_reshape = T.reshape(A_shared, [N])
            T.copy(A_shared_reshape, B)

    return main


def run_reshape_layout_transform_shared(N, M, dtype):
    program = reshape_layout_transform_shared(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_layout_transform_shared():
    run_reshape_layout_transform_shared(1024, 32, T.float32)
    run_reshape_layout_transform_shared(2048, 64, T.float16)


def reduce_after_reshape_test(N, M, dtype):
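    # Kernel: reshape a 1-D register fragment to (N // M, M) and take a
    # row-wise max with T.reduce_max(dim=1); the reference is
    # torch.max(A.reshape(N // M, M), dim=1).values.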
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N,), dtype, scope="shared")
            A_local = T.alloc_fragment((N,), dtype)
            B_local = T.alloc_fragment((N // M,), dtype)

            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            A_local_reshape = T.reshape(A_local, [N // M, M])
            T.reduce_max(A_local_reshape, B_local, dim=1)
            T.copy(B_local, B)

    return main


def run_reduce_after_reshape(N, M, dtype):
    program = reduce_after_reshape_test(N, M, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return torch.max(A.reshape(N // M, M), dim=1).values

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_after_reshape():
    run_reduce_after_reshape(1024, 32, T.float32)
    run_reduce_after_reshape(2048, 64, T.float16)


def reshape_shape_mismatch_test(N, M, dtype):
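    # Kernel: request a reshape to (N // M, M + 1), whose element count does
    # not match N; constructing this prim_func is expected to fail with an
    # AssertionError (checked in test_reshape_shape_mismatch below).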
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M + 1])
            T.copy(A_reshaped, B)

    return main


def test_reshape_shape_mismatch():
    with pytest.raises(AssertionError):
        reshape_shape_mismatch_test(1024, 32, T.float32)


if __name__ == "__main__":
    tilelang.testing.main()