# test_tilelang_jit_nullptr.py
#
# Checks that TileLang JIT kernels accept ``None`` (a null pointer) for an
# optional bias operand that is only read when the runtime flag ``with_bias``
# is true, both for raw-pointer arguments and for typed tensor arguments.
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type


@tl.jit
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a GEMM (+ optional bias) kernel whose operands are raw pointers.

    The kernel computes ``C = A @ B.T`` tile by tile. When the runtime flag
    ``with_bias`` is true, it also adds ``Bias[j]`` to every element of column
    ``j``. ``Bias`` is only read inside the ``with_bias`` branch, which is what
    allows callers to pass ``None`` (a null pointer) for it when the flag is
    false.

    Args:
        M, N: compile-time sizes, used here only to size the launch grid.
        K: kept for signature symmetry with ``tensor_null_test``; the reduction
            loop uses the runtime ``k`` argument instead.
        block_M, block_N, block_K: tile sizes.
        dtype: element type of A and B.
        accum_dtype: element type of the accumulator, C and Bias.

    Returns:
        The ``T.prim_func`` wrapped by ``@tl.jit``. Its runtime arguments are
        ``(a_ptr, b_ptr, c_ptr, bias_ptr, m, n, k, with_bias)``.

    NOTE(review): the grid comes from the compile-time M and N, while the tensor
    views use the runtime m and n. Callers presumably pass m == M and n == N.
    """

    @T.prim_func
    def main(
        a_ptr: T.ptr,
        b_ptr: T.ptr,
        c_ptr: T.ptr,
        bias_ptr: T.ptr,
        m: T.int32,
        n: T.int32,
        k: T.int32,
        with_bias: T.bool,
    ):
        A = T.make_tensor(a_ptr, (m, k), dtype)
        # B is laid out as (n, k): the tile copy below indexes it as [n, k] and
        # the GEMM consumes it with transpose_B=True.
        B = T.make_tensor(b_ptr, (n, k), dtype)
        C = T.make_tensor(c_ptr, (m, n), accum_dtype)
        Bias = T.make_tensor(bias_ptr, (n,), accum_dtype)

        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Copy the current K-slice tiles of A and B, then accumulate.
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Bias is dereferenced only here, so it may be null when with_bias is false.
            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a GEMM (+ optional bias) kernel whose operands are typed tensors.

    This is the same computation as ``ptr_null_test`` (``C = A @ B.T``, plus
    ``Bias[j]`` added to column ``j`` when ``with_bias`` is true), but the
    shapes are fixed at compile time through ``T.Tensor`` annotations.
    ``Bias`` is only read inside the ``with_bias`` branch, so callers may pass
    ``None`` for it when the flag is false.

    Args:
        M, N, K: compile-time GEMM sizes; A is (M, K), B is (N, K), C is (M, N).
        block_M, block_N, block_K: tile sizes.
        dtype: element type of A and B.
        accum_dtype: element type of the accumulator, C and Bias.

    Returns:
        The ``T.prim_func`` wrapped by ``@tl.jit``. Its runtime arguments are
        ``(A, B, C, Bias, with_bias)``.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            # B is (N, K): it is tiled as [n, k] and multiplied with transpose_B=True.
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
            Bias: T.Tensor((N,), accum_dtype),
            with_bias: T.bool,
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy the current K-slice tiles of A and B, then accumulate.
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Bias is dereferenced only here, so it may be None when with_bias is false.
            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Run both null-pointer kernels and compare them against PyTorch.

    Each kernel is launched twice:

    - once with ``None`` as the bias argument and ``with_bias=False``;
    - once with a real bias tensor and ``with_bias=True``.

    The output buffer is zeroed before every launch, so each assertion reflects
    only the launch just performed.

    Args:
        M, N, K: GEMM sizes; the expected result is ``A[M, K] @ B[N, K].T``
            (plus ``bias[N]`` broadcast over rows).
        block_M, block_N, block_K: tile sizes forwarded to the kernel builders.
        dtype: input dtype name, mapped with ``map_torch_type``.
        accum_dtype: accumulator/output dtype name, mapped with ``map_torch_type``.
    """
    in_dtype = map_torch_type(dtype)
    out_dtype = map_torch_type(accum_dtype)

    a = torch.randn(M, K, device="cuda", dtype=in_dtype)
    # B is stored as (N, K); the kernels multiply with transpose_B=True.
    b = torch.randn(N, K, device="cuda", dtype=in_dtype)
    c = torch.zeros(M, N, device="cuda", dtype=out_dtype)
    d = torch.randn(N, device="cuda", dtype=out_dtype)

    # Compute the reference in the accumulator precision, as the kernels do,
    # instead of rounding the product through the (lower) input precision.
    ref_no_bias = a.to(out_dtype) @ b.to(out_dtype).T
    ref_with_bias = ref_no_bias + d

    # Raw-pointer kernel: the bias pointer is null when with_bias is False.
    kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

    c.zero_()
    kernel(a, b, c, None, M, N, K, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)

    c.zero_()
    kernel(a, b, c, d, M, N, K, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)

    # Typed-tensor kernel: the Bias tensor is None when with_bias is False.
    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

    c.zero_()
    kernel(a, b, c, None, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)

    c.zero_()
    kernel(a, b, c, d, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)


def test_nullptr():
    """Exercise the null-bias kernels on a 1024^3 GEMM with 128x128x32 tiles."""
    run_test(
        M=1024,
        N=1024,
        K=1024,
        block_M=128,
        block_N=128,
        block_K=32,
    )


# Allow running this file directly; tilelang.testing.main() takes over
# test discovery and execution from here.
if __name__ == "__main__":
    tilelang.testing.main()