"docs/vscode:/vscode.git/clone" did not exist on "16748d1eba10a96ad42d470000fbabbe61e34d2e"
test_tilelang_kernel_gemm_simt.py 5.98 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@simplify_prim_func
def tl_matmul_simt(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    # This is a debug config
    block_size_x = 8
    block_size_y = 8
    thread_row_tiles = 16
    thread_col_tiles = 16
    chunk = 16

    shared_scope = "shared"

    block_M = block_size_x * thread_row_tiles
    block_N = block_size_y * thread_col_tiles
    block_K = chunk

    # Pipeline Stage

    A_shape = (M, K)
    B_shape = (N, K)
    C_shape = (M, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)

    threads = thread_row_tiles * thread_col_tiles
    local_size_a = block_M // thread_row_tiles
    local_size_b = block_N // thread_col_tiles
    local_size_c = (block_M // thread_row_tiles) * (block_N // thread_col_tiles)

    micro_size_k = 128 // DataType(in_dtype).bits
    dp4a_size = 4
    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor(C_shape, out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)

            A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype)
            B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype)
            C_local = T.alloc_local((local_size_c,), accum_dtype)

            tid = T.get_thread_binding()

            warp_m = tid % thread_row_tiles
            warp_n = tid // thread_row_tiles

            T.clear(C_local)

            for ko in T.serial(K // block_K):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial((block_K // micro_size_k)):
                    for i in T.serial(local_size_a):
                        for mk in T.vectorized(micro_size_k):
                            A_local[i, mk] = A_shared[warp_m * local_size_a + i,
                                                      ki * micro_size_k + mk]

                    for i in T.serial(local_size_b):
                        for mk in T.vectorized(micro_size_k):
                            B_local[i, mk] = B_shared[warp_n * local_size_b + i,
                                                      ki * micro_size_k + mk]

                    for i, j in T.grid(local_size_a, local_size_b):
                        for mk in T.serial(micro_size_k // dp4a_size):
                            if use_dp4a:
                                T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size],
                                       C_local[i * local_size_b + j])
                            else:
                                for dp4a_idx in T.serial(dp4a_size):
                                    C_local[i * local_size_b +
                                            j] += A_local[i, mk * dp4a_size +
                                                          dp4a_idx] * B_local[j, mk * dp4a_size +
                                                                              dp4a_idx]

            for i, j in T.grid(local_size_a, local_size_b):
                C[by * block_M + warp_m * local_size_a + i,
                  bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]

    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype)
    kernel = tilelang.compile(matmul, out_idx=[2])
    profiler = kernel.get_profiler()

    src_code = kernel.get_kernel_source()
    print(src_code)
    # src_code is the generated cuda source
    assert src_code is not None

    if in_dtype == "int8":
        A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))

    C = kernel(A, B)

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")


if __name__ == "__main__":
    tilelang.testing.main()