# test_tilelang_cpu_gemm.py
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
5
import torch
6

7
8
# Turn off tilelang's cache for every test in this module
# (presumably so each run lowers/compiles the kernels from scratch -- TODO confirm).
tilelang.disable_cache()

9
10
11
12
13
14

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a tiled CPU GEMM ``T.prim_func`` computing ``C = A @ B``.

    Args:
        M, N, K: Problem sizes; ``A`` is declared ``(M, K)``, ``B`` is
            ``(K, N)`` and ``C`` is ``(M, N)``.
        block_M, block_N, block_K: Tile sizes along M, N and K. The K loop
            runs ``K // block_K`` iterations, so any remainder of K is not
            visited.
        dtype: Element type of ``A``, ``B`` and ``C``.
        accum_dtype: Element type of the per-tile accumulator.

    Returns:
        The ``T.prim_func`` describing the kernel.
    """
    num_stages = 0

    @T.prim_func
    def matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # bx indexes tiles along N, by indexes tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
            A_local = T.alloc_local((block_M, block_K), dtype)
            B_local = T.alloc_local((block_K, block_N), dtype)
            C_local = T.alloc_local((block_M, block_N), accum_dtype)

            T.clear(C_local)

            # Layout optimizations (or a custom layout) may optionally be
            # applied to A_local / B_local here via T.annotate_layout.

            for ko in T.Pipelined(K // block_K, num_stages=num_stages):

                # Rows of A come from the M tile (by), columns from the K step.
                T.copy(A[by * block_M, ko * block_K], A_local)

                # Or Copy with Parallel
                # Columns of B belong to the N tile, so they are offset by bx;
                # using by here would pair the M tile index with block_N and
                # load the wrong columns.
                for k, j in T.Parallel(block_K, block_N):
                    B_local[k, j] = B[ko * block_K + k, bx * block_N + j]

                for i, j, k in T.grid(block_M, block_N, block_K):
                    C_local[i, j] += A_local[i, k] * B_local[k, j]

            # Write the accumulated tile back to its (by, bx) position in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul


def assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32):
    """Lower the tiled matmul kernel with target "c" and assert that source was produced.

    Only the presence of ``kernel_source`` on the lowered artifact is checked;
    the generated code is neither compiled nor executed here.
    """
    kernel = matmul(M, N, K, block_M, block_N, block_K)
    lowered = tilelang.lower(kernel, target="c")
    generated_source = lowered.kernel_source

    assert generated_source is not None, "Code generation failed"


def test_matmul_codegen():
    """Run the codegen check on a 1024x1024x1024 problem with 128x128x32 tiles."""
    # Positional order: M, N, K, block_M, block_N, block_K.
    assert_matmul_codegen(1024, 1024, 1024, 128, 128, 32)


65
66
67
68
69
70
def test_matmul_compile():
    """Compile a serial-loop CPU GEMM with target "c" and compare it to ``torch.matmul``.

    The kernel is built for M, N, K = 1024, 512, 512 with four tiles per
    dimension, compiled through the ctypes execution backend, run on random
    float16 inputs, and checked with ``tilelang.testing.torch_assert_close``.
    """

    def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
        # a simple kernel just for jit test
        @T.prim_func
        def matmul(
                A: T.Tensor((M, K), dtype),
                B: T.Tensor((K, N), dtype),
                C: T.Tensor((M, N), dtype),
        ):
            # bx indexes tiles along N, by indexes tiles along M.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
                A_local = T.alloc_local((block_M, block_K), dtype)
                B_local = T.alloc_local((block_K, block_N), dtype)
                C_local = T.alloc_local((block_M, block_N), accum_dtype)

                # Zero the accumulator tile.
                for p in T.serial(block_M):
                    for w in T.serial(block_N):
                        C_local[p, w] = 0
                for ko in T.serial(K // block_K):
                    # Load the A tile for this M block / K step.
                    for i in T.serial(block_M):
                        for k in T.serial(block_K):
                            A_local[i, k] = A[by * block_M + i, ko * block_K + k]

                    # Load the B tile for this K step / N block.
                    for k in T.serial(block_K):
                        for j in T.serial(block_N):
                            B_local[k, j] = B[ko * block_K + k, bx * block_N + j]

                    # Accumulate the tile product.
                    for i in T.serial(block_M):
                        for j in T.serial(block_N):
                            for k in T.serial(block_K):
                                C_local[i, j] += A_local[i, k] * B_local[k, j]

                # Store the finished tile into C.
                for i in T.serial(block_M):
                    for j in T.serial(block_N):
                        C[by * block_M + i, bx * block_N + j] = C_local[i, j]

        return matmul

    M, N, K = 1024, 512, 512
    block_M, block_N, block_K = M // 4, N // 4, K // 4

    # One dtype name drives both the kernel signature and the torch inputs.
    in_dtype = "float16"
    cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype=in_dtype)
    # NOTE(review): the positional -1 presumably selects the last kernel argument (C)
    # as the returned output, since the compiled function is called with only A and B
    # -- confirm against tilelang.compile.
    compiled_fn = tilelang.compile(cpu_func, -1, execution_backend="ctypes", target="c")

    torch_dtype = getattr(torch, in_dtype)
    A = torch.randn(M, K, dtype=torch_dtype)
    B = torch.randn(K, N, dtype=torch_dtype)

    C = compiled_fn(A, B)
    C_torch = torch.matmul(A, B)

    tilelang.testing.torch_assert_close(C, C_torch, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


118
119
# When executed directly as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()