# gemm_tcgen5mma.py
# TileLang GEMM example: builds a tiled matmul kernel that accumulates into a
# T.alloc_tmem buffer, checks it against a PyTorch reference, and benchmarks it.
import torch
import tilelang
import tilelang.language as T

# Disable TileLang's kernel cache (per the API name). Presumably this makes every
# run recompile `matmul` instead of reusing a cached build; confirm in tilelang docs.
tilelang.disable_cache()


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM ``C = op(A) @ op(B)`` as a TileLang prim_func.

    ``op(X)`` is ``X`` transposed when the matching ``trans_*`` flag is set:
    A is stored as (K, M) when ``trans_A`` else (M, K), and B as (N, K) when
    ``trans_B`` else (K, N). Each thread block computes one
    ``block_M x block_N`` tile of C, accumulating partial products in a
    ``T.alloc_tmem`` buffer over ``ceildiv(K, block_K)`` pipelined k-steps.

    Args:
        M, N, K: Problem sizes (C is M x N, reduction length K).
        block_M, block_N, block_K: Tile sizes.
        trans_A, trans_B: Storage-layout flags for A and B (see above).
        in_dtype: Element type of A and B.
        out_dtype: Element type of C.
        accum_dtype: Element type of the accumulator.
        num_stages: Pipelining depth passed to ``T.Pipelined``.
        threads: Threads per block passed to ``T.Kernel``.

    Returns:
        The ``T.prim_func`` ``main(A, B, C)``.
    """
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            # Barrier handed to T.gemm and waited on after every k-step
            # (wg_wait=-1 presumably defers completion to this barrier).
            mbar = T.alloc_barrier(1)

            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # The tile origin depends on the storage layout, matching the
                # A_shape / B_shape selection above. trans_A / trans_B are
                # Python bools, so the branch is expected to be resolved while
                # the kernel is built (check the printed kernel source).
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                # The expected parity alternates 0, 1, 0, ... with k, so one
                # barrier serves every k-step.
                T.mbarrier_wait_parity(mbar, k % 2)

            # Move the accumulator out of TMEM, down-convert via shared memory,
            # then write this block's tile of C.
            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
              accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })

print(jit_kernel.get_kernel_source())

# Build inputs in the storage layout and dtype that matmul() declares for A
# and B, so the correctness check stays valid if the constants above change.
torch_in_dtype = getattr(torch, in_dtype)
torch_out_dtype = getattr(torch, out_dtype)
a_shape = (K, M) if trans_A else (M, K)
b_shape = (N, K) if trans_B else (K, N)
a = torch.randn(*a_shape, device="cuda", dtype=torch_in_dtype)
b = torch.randn(*b_shape, device="cuda", dtype=torch_in_dtype)
c = jit_kernel(a, b)

# fp32 reference: undo the storage transposes to get (M, K) @ (K, N).
a_ref = a.T if trans_A else a
b_ref = b.T if trans_B else b
ref_c = (a_ref.to(torch.float) @ b_ref.to(torch.float)).to(torch_out_dtype)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()  # milliseconds, as printed below
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")