# test_tilelang_language_pipeline.py
# Tests for TileLang software-pipelined loops (T.Pipelined): explicit
# order/stage lists on a dense GEMM and num_stages on a block-sparse GEMM.
from tilelang import tvm as tvm
import tilelang.testing


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    """Build a tiled GEMM ``T.prim_func`` whose K loop is a ``T.Pipelined`` loop.

    Args:
        M, N, K: Problem sizes. ``C`` is declared as ``(M, N)``; ``A`` is
            ``(M, K)`` (or ``(K, M)`` when ``trans_A``) and ``B`` is ``(K, N)``
            (or ``(N, K)`` when ``trans_B``).
        block_M, block_N, block_K: Tile sizes used for the launch grid, the
            ``A_shared``/``B_shared`` tiles and the ``C_local`` accumulator.
        trans_A, trans_B: Select the transposed layout for ``A`` / ``B`` and
            are also forwarded to ``T.gemm``.
        in_dtype: Element type of ``A``, ``B`` and both shared tiles.
        out_dtype: Element type of ``C``.
        accum_dtype: Element type of the ``C_local`` accumulator.
        threads: Forwarded to ``T.Kernel(threads=...)``.
        order, stage: Forwarded unchanged to ``T.Pipelined``.

    Returns:
        The ``main`` prim_func defined below.
    """
    # Global and per-tile shapes follow the (optional) transposition of A / B.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        # One kernel block per (block_M x block_N) tile of C: bx indexes the
        # N direction, by the M direction.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # The loop body consists of three top-level statements (copy of A,
            # copy of B, gemm) and the tests in this file pass three-entry
            # order/stage lists -- presumably one entry per statement; confirm
            # against the T.Pipelined documentation. Do not reorder or merge
            # these statements without revisiting those lists.
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                # trans_A / trans_B are plain Python values captured from the
                # enclosing function; presumably these branches are resolved
                # when the prim_func is built -- TODO confirm.
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            # Write the accumulated tile back to its position in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    order,
    stage,
):
    """Compile the pipelined GEMM for one ``order``/``stage`` pair and check it.

    A fixed 1024x1024x1024 float16 problem (float32 accumulation, 128x128x32
    tiles, 128 threads, no transposition) is built with ``matmul``, compiled
    with TMA lowering and warp specialization disabled, and then handed to the
    kernel's profiler together with a PyTorch reference.

    Args:
        order: Forwarded to ``matmul`` and from there to ``T.Pipelined``.
        stage: Forwarded to ``matmul`` and from there to ``T.Pipelined``.

    Raises:
        Whatever ``profiler.assert_allclose`` raises when the kernel output and
        the reference disagree beyond ``atol=1e-2`` / ``rtol=1e-2``.
    """
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    trans_A = False
    trans_B = False
    in_dtype = "float16"
    out_dtype = "float16"
    accum_dtype = "float32"
    num_threads = 128
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_threads,
        order,
        stage,
    )

    # out_idx=[2] presumably marks the third prim_func argument (C) as the
    # kernel output, which matches ref_program taking only A and B -- confirm
    # against the tilelang.compile documentation.
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference result: float32 matmul of A and B, cast to ``out_dtype``."""
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically. 0x1000 is subtracted from the raw int32
            # bit pattern of every element before it is reinterpreted as float.
            # NOTE(review): the original comment was cut off ("-0x1000 meas");
            # confirm the intended rounding rationale.
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        # getattr() is the supported way to resolve a dtype by name; calling
        # the __getattribute__ dunder directly bypasses normal attribute lookup.
        C = C.to(getattr(torch, out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_pipeline_order_stage():
    """Check the pipelined GEMM for each order/stage combination, in sequence."""
    order_stage_pairs = (
        ([0, 1, 2], [0, 0, 1]),
        ([0, 1, 2], [0, 0, 2]),
        ([1, 2, 0], [0, 0, 2]),
        ([1, 2, 0], [0, 0, 1]),
    )
    for pipeline_order, pipeline_stage in order_stage_pairs:
        run_gemm(order=pipeline_order, stage=pipeline_stage)


# out_idx=[-1] presumably marks the last prim_func argument (C) as the kernel
# output: the caller in this file passes only (a, b, block_mask) and receives
# c as the return value -- confirm against the tilelang.jit documentation.
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def blocksparse_matmul(M,
                       N,
                       K,
                       block_M,
                       block_N,
                       block_K,
                       num_stages,
                       dtype="float16",
                       accum_dtype="float"):
    """Build a block-sparse GEMM prim_func with a ``num_stages`` pipelined K loop.

    Args:
        M, N, K: Problem sizes; ``A`` is ``(M, K)``, ``B`` is ``(K, N)`` and
            ``C`` is ``(M, N)``.
        block_M, block_N, block_K: Tile sizes for the grid, the shared tiles
            and the accumulator; they also define the block-mask shape.
        num_stages: Forwarded to ``T.Pipelined(num_stages=...)``.
        dtype: Element type of ``A``, ``B``, ``C`` and the shared tiles.
        accum_dtype: Element type of the ``C_local`` accumulator.

    Returns:
        The ``block_sparse_matmul`` prim_func defined below.
    """

    # NOTE(review): the mask shape uses floor division while the launch grid
    # and the K loop below use T.ceildiv. The two agree only when M, N and K
    # are multiples of their block sizes (true for the test in this file);
    # confirm before using non-divisible shapes.
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    import tilelang.language as T

    @T.prim_func
    def block_sparse_matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            BlockMask: T.Tensor(block_mask_shape, "bool"),
            C: T.Tensor((M, N), dtype),
    ):
        # One kernel block per (block_M x block_N) tile of C: bx indexes the
        # N direction, by the M direction.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Single-element buffer holding the mask flag of the current K block.
            block_mask = T.alloc_local((1,), "bool")
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            T.clear(C_local)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load and accumulate only the K blocks whose mask entry for
                # this output tile is set; unset blocks contribute nothing.
                block_mask[0] = BlockMask[by, bx, k]
                if block_mask[0]:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)

            # The accumulator goes through C_shared before being written to C.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return block_sparse_matmul


def run_blocksparse_matmul(num_stages):
    """Run the block-sparse GEMM kernel once and compare it with a PyTorch reference.

    A 256x256x256 float16 problem with 128x128x32 tiles is used. The inputs and
    the block mask are drawn from a locally seeded generator so that a failure
    can be reproduced exactly; the global torch RNG state is left untouched.

    Args:
        num_stages: Pipeline depth forwarded to ``blocksparse_matmul``.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when the kernel
            output differs from the reference beyond ``rtol=1e-2`` / ``atol=1e-2``.
    """
    import torch

    M = 256
    N = 256
    K = 256
    block_M = 128
    block_N = 128
    block_K = 32
    sparsity = 0.5

    # CPU generator with a fixed seed: tensors are created on the host and
    # then moved to the GPU, so every run sees identical data.
    rng = torch.Generator().manual_seed(0)

    # Initialize input matrices A and B on the GPU with half precision
    a = torch.randn(M, K, generator=rng).cuda().half()
    b = torch.randn(K, N, generator=rng).cuda().half()

    kernel = blocksparse_matmul(
        M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
    # Kept from the original test: dumps the generated kernel source to stdout.
    print(kernel.get_kernel_source())
    # Create block mask with desired sparsity: an entry is True (block kept)
    # when its uniform sample exceeds `sparsity`.
    mask_shape = (M // block_M, N // block_N, K // block_K)
    block_mask = torch.rand(mask_shape, generator=rng).cuda() > sparsity

    # Run the compiled kernel (either tuned or default) with the inputs
    c = kernel(a, b, block_mask)

    def ref_program(A, B, BlockMask, block_M, block_N, block_K):
        """Naive block-by-block reference; sizes are taken from A and B."""
        rows, depth = A.shape
        cols = B.shape[1]
        ref_c = torch.zeros((rows, cols), dtype=torch.float16, device=A.device)
        for i in range(rows // block_M):
            row_span = slice(i * block_M, (i + 1) * block_M)
            for j in range(cols // block_N):
                col_span = slice(j * block_N, (j + 1) * block_N)
                # Accumulate in float32, then cast once to float16.
                accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
                for k in range(depth // block_K):
                    if BlockMask[i, j, k]:
                        k_span = slice(k * block_K, (k + 1) * block_K)
                        accu += (A[row_span, k_span].to(torch.float32)
                                 @ B[k_span, col_span].to(torch.float32))
                ref_c[row_span, col_span] = accu.to(torch.float16)
        return ref_c

    # Compute the reference result using the naive PyTorch implementation
    ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)

    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


def test_blocksparse_matmul():
    """Check the block-sparse GEMM with one, two and three pipeline stages."""
    for stage_count in (1, 2, 3):
        run_blocksparse_matmul(num_stages=stage_count)


# Script entry point: when the file is executed directly, hand control to
# tilelang's test entry point (presumably it runs the test_* functions in
# this file -- confirm in tilelang.testing).
if __name__ == "__main__":
    tilelang.testing.main()