# test_tilelang_language_cumsum.py
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch


7
def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
    """Build a tiled 2D cumsum kernel that scans inside a shared buffer.

    Args:
        M: Number of rows of the input ``A`` and output ``B``.
        N: Number of columns of the input ``A`` and output ``B``.
        block_M: Tile height handled by one kernel block.
        block_N: Tile width handled by one kernel block.
        dim: Axis forwarded to ``T.cumsum``.
        reverse: ``reverse`` flag forwarded to ``T.cumsum``.
        dtype: Element type name used for ``A``, ``B`` and the tile buffer.

    Returns:
        The ``T.prim_func`` describing the kernel (not compiled here).
    """
    import tilelang.language as T

    @T.prim_func
    def cumsum(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: one block per (block_M, block_N) tile,
        # with bx indexing tiles along N and by indexing tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            T.copy(A[by * block_M, bx * block_N], A_shared)
            # No separate destination is given; the scanned tile is read back
            # from A_shared by the copy below.
            T.cumsum(src=A_shared, dim=dim, reverse=reverse)
            T.copy(A_shared, B[by * block_M, bx * block_N])

    return cumsum


26
def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
    """Build a tiled 2D cumsum kernel that scans inside a fragment buffer.

    Args:
        M: Number of rows of the input ``A`` and output ``B``.
        N: Number of columns of the input ``A`` and output ``B``.
        block_M: Tile height handled by one kernel block.
        block_N: Tile width handled by one kernel block.
        dim: Axis forwarded to ``T.cumsum``.
        reverse: ``reverse`` flag forwarded to ``T.cumsum``.
        dtype: Element type name used for ``A``, ``B`` and both tile buffers.

    Returns:
        The ``T.prim_func`` describing the kernel (not compiled here).
    """
    import tilelang.language as T

    @T.prim_func
    def cumsum(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: one block per (block_M, block_N) tile,
        # with bx indexing tiles along N and by indexing tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            A_fragment = T.alloc_fragment((block_M, block_N), dtype)

            # The tile is staged through the shared buffer before it reaches
            # the fragment buffer that is actually scanned.
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, A_fragment)
            T.cumsum(src=A_fragment, dim=dim, reverse=reverse)
            T.copy(A_fragment, B[by * block_M, bx * block_N])

    return cumsum


47
def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", scope="smem"):
    """Compile a 2D tiled cumsum kernel and compare it with a torch reference.

    Args:
        M: Number of rows of the random input.
        N: Number of columns of the random input.
        block_M: Tile height; the cumulative sum restarts in every tile.
        block_N: Tile width; the cumulative sum restarts in every tile.
        dim: Axis along which each tile is scanned.
        reverse: When true, each tile is scanned from the end of ``dim``.
        dtype: Name of the torch dtype used for the input.
        scope: ``"smem"`` or ``"fragment"``; selects which kernel builder
            is used.

    Raises:
        ValueError: If ``scope`` is not one of the supported values.
        AssertionError: If the kernel output is not close to the reference.
    """
    if scope == "smem":
        program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype)
    elif scope == "fragment":
        program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype)
    else:
        # Fail early with the same message as run_cumsum_1d instead of
        # reaching tl.compile with `program` unbound.
        raise ValueError(f"Unknown scope {scope}")

    jit_kernel = tl.compile(program, out_idx=-1)
    A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()

    def ref_program(A):
        """Return the per-tile cumulative sum of ``A`` computed with torch."""
        ref_b = torch.empty_like(A)
        # Ceil division matches the kernel grid (T.ceildiv), so a trailing
        # partial tile is filled too rather than left as uninitialised memory.
        # NOTE(review): assumes the kernel treats partial tiles like the
        # clipped torch slice, as run_cumsum_1d already does - confirm.
        for i in range((M + block_M - 1) // block_M):
            rows = slice(i * block_M, min((i + 1) * block_M, M))
            for j in range((N + block_N - 1) // block_N):
                cols = slice(j * block_N, min((j + 1) * block_N, N))
                tile = A[rows, cols]
                # A reverse scan is flip -> forward scan -> flip back.
                if reverse:
                    tile = tile.flip(dims=[dim])
                tile = tile.cumsum(dim=dim)
                if reverse:
                    tile = tile.flip(dims=[dim])
                ref_b[rows, cols] = tile
        return ref_b

    tilelang_res = jit_kernel(A)
    ref_res = ref_program(A)
    torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
75
76


77
78
79
80
81
def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
    """Build a tiled 1D cumsum kernel that scans inside a shared buffer.

    Args:
        N: Length of the input ``A`` and output ``B``.
        block_N: Number of elements handled by one kernel block; also used
            as the thread count of the block.
        reverse: ``reverse`` flag forwarded to ``T.cumsum``.
        dtype: Element type name used for ``A``, ``B`` and the tile buffer.

    Returns:
        The ``T.prim_func`` describing the kernel (not compiled here).
    """
    import tilelang.language as T

    @T.prim_func
    def cumsum(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
    ):
        # One block per chunk of block_N elements.
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared((block_N,), dtype)

            T.copy(A[bx * block_N], A_shared)
            # No separate destination is given; the scanned chunk is read
            # back from A_shared by the copy below.
            T.cumsum(src=A_shared, dim=0, reverse=reverse)
            T.copy(A_shared, B[bx * block_N])

    return cumsum


def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
    """Build a tiled 1D cumsum kernel that scans inside a fragment buffer.

    Args:
        N: Length of the input ``A`` and output ``B``.
        block_N: Number of elements handled by one kernel block; also used
            as the thread count of the block.
        reverse: ``reverse`` flag forwarded to ``T.cumsum``.
        dtype: Element type name used for ``A``, ``B`` and both buffers.

    Returns:
        The ``T.prim_func`` describing the kernel (not compiled here).
    """
    import tilelang.language as T

    @T.prim_func
    def cumsum(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
    ):
        # One block per chunk of block_N elements.
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared((block_N,), dtype)
            A_fragment = T.alloc_fragment((block_N,), dtype)

            # The chunk is staged through the shared buffer before it reaches
            # the fragment buffer that is actually scanned.
            T.copy(A[bx * block_N], A_shared)
            T.copy(A_shared, A_fragment)
            T.cumsum(src=A_fragment, dim=0, reverse=reverse)
            T.copy(A_fragment, B[bx * block_N])

    return cumsum


def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"):
    """Compile a 1D tiled cumsum kernel and compare it with a torch reference.

    Args:
        N: Length of the random input vector.
        block_N: Chunk length; the cumulative sum restarts in every chunk.
        reverse: When true, each chunk is scanned from its end.
        dtype: Name of the torch dtype used for the input.
        scope: ``"smem"`` or ``"fragment"``; selects the kernel builder.

    Raises:
        ValueError: If ``scope`` is not one of the supported values.
        AssertionError: If the kernel output is not close to the reference.
    """
    if scope not in ("smem", "fragment"):
        raise ValueError(f"Unknown scope {scope}")
    build = cumsum_smem_test_1d if scope == "smem" else cumsum_fragment_test_1d

    kernel = tl.compile(build(N, block_N, reverse, dtype), out_idx=-1)
    A = torch.randn(N, dtype=getattr(torch, dtype)).cuda()

    def ref_program(A):
        """Return the per-chunk cumulative sum of ``A`` computed with torch."""
        expected = torch.empty_like(A)
        chunk_count = (N + block_N - 1) // block_N
        for idx in range(chunk_count):
            lo = idx * block_N
            hi = min(lo + block_N, N)  # the last chunk may be shorter
            segment = A[lo:hi]
            if reverse:
                # Reverse scan: flip, scan forward, flip back.
                scanned = segment.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
            else:
                scanned = segment.cumsum(dim=0)
            expected[lo:hi] = scanned
        return expected

    # Arguments evaluate left to right: kernel first, then the reference.
    torch.testing.assert_close(kernel(A), ref_program(A), atol=1e-3, rtol=1e-3)


146
147
148
149
150
151
152
153
def test_cumsum_smem():
    """Run the 2D cumsum check with the default shared-buffer kernel."""
    # Test different sizes
    run_cumsum(1024, 1024, 128, 128)
    run_cumsum(1024, 1024, 128, 128, dim=1)
    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True)

    # Test different dtypes
    # NOTE(review): both calls below use float32, so only one dtype is
    # actually exercised; the second was presumably meant to cover another
    # dtype - confirm before changing it.
    run_cumsum(256, 256, 128, 128, dtype="float32")
    run_cumsum(256, 256, 128, 128, dtype="float32")
155
156
157
158
159
160
161
162
163


def test_cumsum_fragment():
    """Run the 2D cumsum check with the fragment-buffer kernel."""
    run_cumsum(1024, 1024, 128, 128, scope="fragment")
    run_cumsum(1024, 1024, 128, 128, dim=1, scope="fragment")
    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment")

    # Test different dtypes
    # NOTE(review): both calls below use float32, so only one dtype is
    # actually exercised; the second was presumably meant to cover another
    # dtype - confirm before changing it.
    run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
    run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
165
166


167
168
169
170
171
172
173
174
175
176
def test_cumsum_smem_1d():
    """Run the 1D cumsum check (default scope), forward then reverse."""
    for backwards in (False, True):
        run_cumsum_1d(1024, 128, reverse=backwards)


def test_cumsum_fragment_1d():
    """Run the 1D cumsum check with the fragment kernel, forward then reverse."""
    for backwards in (False, True):
        run_cumsum_1d(1024, 128, reverse=backwards, scope="fragment")


177
178
# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()