"...resnet50_tensorflow.git" did not exist on "9ca59f8a50d079e89dc799dfd81c7597f59b3b79"
test_tilelang_jit_nullptr.py 3.91 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
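"""Tests that tilelang JIT kernels accept a null pointer (Python ``None``)
for an argument that is never dereferenced: both kernel variants below read
their bias argument only when ``with_bias`` is true, so calling them with
``None`` and ``with_bias=False`` must succeed, while a real bias tensor with
``with_bias=True`` must yield the biased result."""
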
import torch
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type


@tl.jit
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
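    """Tiled GEMM (C = A @ B^T, optional bias) taking raw pointers plus
    runtime shapes; ``bias_ptr`` may be null when ``with_bias`` is false."""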

    @T.prim_func
    def main(
        a_ptr: T.ptr,
        b_ptr: T.ptr,
        c_ptr: T.ptr,
        bias_ptr: T.ptr,
        m: T.int32,
        n: T.int32,
        k: T.int32,
        with_bias: T.bool,
    ):
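        # Bind the raw pointers to tensor views with runtime shapes. B is
        # laid out as (n, k) because the GEMM below reads it transposed.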
        A = T.make_tensor(a_ptr, (m, k), dtype)
        B = T.make_tensor(b_ptr, (n, k), dtype)
        C = T.make_tensor(c_ptr, (m, n), accum_dtype)
        Bias = T.make_tensor(bias_ptr, (n,), accum_dtype)

        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Stage the current K-tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # Accumulate A_shared @ B_shared^T into the local fragment
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

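            # Bias is read only on this guarded path, so a null bias_ptr is
            # safe whenever with_bias is false.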
            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
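    """Same GEMM as ``ptr_null_test`` but with typed ``T.Tensor`` parameters
    and compile-time shapes; ``Bias`` may be passed as ``None`` when
    ``with_bias`` is false."""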

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
            Bias: T.Tensor((N,), accum_dtype),
            with_bias: T.bool,
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Stage the current K-tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # Accumulate A_shared @ B_shared^T into the local fragment
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
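    """Run both kernel variants, first without a bias (passing None) and
    then with a real bias tensor, checking results against torch."""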
    func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

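    # b is laid out (N, K); the kernels compute a @ b.T via transpose_B=True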
    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))

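    # The bias argument is None: with_bias=False guarantees the kernel never
    # dereferences the resulting null pointer.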
    func(a, b, c, None, M, N, K, False)

    ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
    ref_with_bias = ref_no_bias + d

    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)

    func(a, b, c, d, M, N, K, True)

    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)

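    # Repeat with the typed-tensor variant, where None stands in for Bias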
    func = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    func(a, b, c, None, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
    func(a, b, c, d, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)


def test_nullptr():
    run_test(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()