test_tilelang_language_reshape.py 3.45 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl


def reshape_test(N, M, dtype):
    """Build a TileLang kernel that reshapes a flat (N,) tensor into (N // M, M).

    Args:
        N: total element count of the 1D input (assumed divisible by M).
        M: inner dimension of the 2D output.
        dtype: element type string, e.g. "float32".

    Returns:
        The ``@T.prim_func`` kernel taking input A and writing output B.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N // M, M), dtype),
    ):
        # Single kernel instance (grid of 1); the block index is unused.
        with T.Kernel(1) as _:
            # Reinterpret A with shape (N // M, M) — presumably a view, no
            # data movement until the copy below; confirm against T.reshape docs.
            A_reshaped = T.reshape(A, [N // M, M])
            T.copy(A_reshaped, B)

    return main


def run_reshape(N, M, dtype):
    """Compile the 1D->2D reshape kernel and validate it against a reference.

    Args:
        N: flat input length.
        M: inner dimension of the (N // M, M) output.
        dtype: element type string for both tensors.
    """
    func = reshape_test(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    compiled = tl.compile(
        func,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })

    def expected(A):
        # Reference: the framework tensor's own reshape.
        return A.reshape(N // M, M)

    compiled.get_profiler().assert_allclose(expected, atol=1e-2, rtol=1e-2)


def test_reshape_smem():
    """Exercise the basic reshape kernel over two sizes and dtypes."""
    for n, m, dtype in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reshape(n, m, dtype)


def reshape_test_smem_1d_2_2d(N, M, dtype):
    """Build a kernel that stages a (N,) input in shared memory, then reshapes
    the shared buffer to (N // M, M) before copying it out to B.

    Args:
        N: total element count of the 1D input (assumed divisible by M).
        M: inner dimension of the 2D output.
        dtype: element type string, e.g. "float32".

    Returns:
        The ``@T.prim_func`` kernel taking input A and writing output B.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            # Stage the whole input into a 1D shared-memory buffer.
            A_shared = T.alloc_shared((N,), dtype)
            for i in T.Parallel(N):
                A_shared[i] = A[i]

            # Reshape the *shared* buffer (1D -> 2D), then copy to global B.
            A_smem_reshaped = T.reshape(A_shared, [N // M, M])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_1d_2_2d(N, M, dtype):
    """Compile the shared-memory 1D->2D reshape kernel and validate it.

    Args:
        N: flat input length.
        M: inner dimension of the (N // M, M) output.
        dtype: element type string for both tensors.
    """
    func = reshape_test_smem_1d_2_2d(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    compiled = tl.compile(
        func,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })

    def expected(A):
        # Reference: the framework tensor's own reshape.
        return A.reshape(N // M, M)

    compiled.get_profiler().assert_allclose(expected, atol=1e-2, rtol=1e-2)


def test_reshape_smem_1d_2_2d():
    """Exercise the shared-memory 1D->2D reshape over two sizes and dtypes."""
    for n, m, dtype in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reshape_smem_1d_2_2d(n, m, dtype)


def reshape_test_smem_2d_2_1d(N, M, dtype):
    """Build a kernel that stages a (N // M, M) input in shared memory, then
    flattens the shared buffer to (N,) before copying it out to B.

    Args:
        N: total element count (assumed divisible by M).
        M: inner dimension of the 2D input.
        dtype: element type string, e.g. "float32".

    Returns:
        The ``@T.prim_func`` kernel taking input A and writing output B.
    """
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N // M, M), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1) as _:
            # Stage the whole 2D input into a shared-memory buffer.
            A_shared = T.alloc_shared((N // M, M), dtype)
            for i, j in T.Parallel(N // M, M):
                A_shared[i, j] = A[i, j]

            # Flatten the *shared* buffer (2D -> 1D), then copy to global B.
            A_smem_reshaped = T.reshape(A_shared, [N])
            T.copy(A_smem_reshaped, B)

    return main


def run_reshape_smem_2d_2_1d(N, M, dtype):
    """Compile the shared-memory 2D->1D reshape kernel and validate it.

    Args:
        N: total element count of the output.
        M: inner dimension of the (N // M, M) input.
        dtype: element type string for both tensors.
    """
    func = reshape_test_smem_2d_2_1d(N, M, dtype)
    # TODO(lei): reshape cannot apply shared memory
    # layout transform propagation
    compiled = tl.compile(
        func,
        out_idx=-1,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })

    def expected(A):
        # Reference: flatten with the framework tensor's own reshape.
        return A.reshape(N)

    compiled.get_profiler().assert_allclose(expected, atol=1e-2, rtol=1e-2)


def test_reshape_smem_2d_2_1d():
    """Exercise the shared-memory 2D->1D reshape over two sizes and dtypes."""
    for n, m, dtype in ((1024, 32, "float32"), (2048, 64, "float16")):
        run_reshape_smem_2d_2_1d(n, m, dtype)


if __name__ == "__main__":
    # Discover and run the test_* functions in this module via tilelang's
    # test harness (presumably a pytest wrapper — confirm in tilelang.testing).
    tilelang.testing.main()