example_tilelang_gemm_fp8_2xAcc.py
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
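# This example implements two-level ("2x") accumulation for FP8 GEMM: the FP8
# tensor-core GEMM accumulates into a short-lived fragment (C_local), and those
# partial sums are periodically promoted into a separate FP32 accumulator
# (C_local_accum) to limit precision loss over long K reductions. The promotion
# interval follows the comments below; exact accumulator behavior is
# hardware-dependent.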


def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    # For FP8 GEMM, promote the partial sums once per 128 elements along K
    # (i.e. after 4 WGMMA instructions when block_K = 128):
    #   - if block_K < 128, promote every 128 // block_K iterations;
    #   - if block_K >= 128, promote after every iteration.
    update_interval = 128 // block_K if block_K < 128 else 1
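    # For example, block_K = 64 gives update_interval = 2 (promote every other
    # iteration), while block_K = 128 or 256 gives update_interval = 1.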

    @T.prim_func
    def gemm_fp8_2xAcc(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # One 128-thread block computes one (block_M, block_N) tile of C;
        # bx tiles the N dimension, by tiles the M dimension.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)  # per-interval accumulator
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)  # full-K FP32 accumulator

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            # Software-pipelined K loop (3 stages): stage the current K-slices of
            # A and B in shared memory, then issue the FP8 GEMM into C_local.
            for k in T.Pipelined(K_iters, num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote: add this interval's partial sums into the FP32
                # accumulator, then reset C_local (2x accumulation)
                if (k + 1) % update_interval == 0:
                    for i, j in T.Parallel(block_M, block_N):
                        C_local_accum[i, j] += C_local[i, j]
                    T.clear(C_local)
            # Tail: fold in remaining partial sums when K_iters is not a
            # multiple of update_interval
            if K_iters % update_interval != 0:
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j]
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_fp8_2xAcc


def calc_diff(x, y):
    # Relative error: 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)); 0 when x == y.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def test_gemm_fp8(M, N, K, dtype):
    torch_dtype = map_torch_type(dtype)

    func = matmul(M, N, K, 128, 128, 64, dtype)

    kernel = tilelang.compile(func, out_idx=-1)  # last argument (C) is allocated and returned as the output

    # Random inputs scaled to [-100, 100], then cast to the requested FP8 format
    a = torch.rand(M, K, dtype=torch.float16, device='cuda')
    a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
    b = torch.rand(N, K, dtype=torch.float16, device='cuda')
    b = (100 * (2 * b - 1)).to(dtype=torch_dtype)

    c = kernel(a, b)

    ref_c = a.float() @ b.float().T  # FP32 reference

    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


def main():
    test_gemm_fp8(1024, 1024, 8192, 'e4m3_float8')
    test_gemm_fp8(1024, 1024, 8192, 'e5m2_float8')


if __name__ == "__main__":
    main()