# example_tilelang_gemm_fp8.py: FP8 GEMM example written with TileLang.
import torch
import tilelang
import tilelang.language as T


def calc_diff(x, y):
    """Return a dissimilarity score between two tensors as a 0-dim float64 tensor.

    The score is ``1 - 2 * <x, y> / (|x|^2 + |y|^2)``. It is 0 when ``x`` and
    ``y`` are element-wise equal and lies in [0, 2] otherwise. Both inputs are
    promoted to float64 first, so the products and sums are computed in full
    precision regardless of the input dtype (the callers in this file pass FP8
    tensors).

    Args:
        x: First tensor.
        y: Second tensor, same shape as ``x``.

    Returns:
        A 0-dim float64 tensor holding the score. If both inputs are entirely
        zero (or empty), the formula would be 0/0 = NaN; 0 is returned instead
        because the tensors are identical.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Identical all-zero / empty inputs: avoid NaN, which would make any
        # `diff < tol` check fail spuriously.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
    """Build a TileLang kernel computing ``C = A @ B.T``.

    Args:
        M, N, K: Problem sizes; ``A`` is (M, K), ``B`` is (N, K), ``C`` is (M, N).
        block_M, block_N, block_K: Tile sizes: each kernel block produces a
            (block_M, block_N) tile of ``C`` and walks K in block_K slices.
        dtype: Element type of ``A``, ``B`` and the output ``C`` (the callers
            in this file pass FP8 types).
        accum_dtype: Element type of the per-tile accumulator. The accumulator
            is copied into ``C``, whose element type is ``dtype``.

    Returns:
        The ``gemm_fp8`` prim_func. ``tilelang.jit(out_idx=[-1])`` marks its
        last parameter (``C``) as the output, so the compiled kernel is called
        as ``kernel(a, b)`` and returns ``c``.
    """

    @T.prim_func
    def gemm_fp8(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Grid: bx indexes column tiles of C (along N), by indexes row tiles (along M).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Walk K in block_K slices; T.Pipelined with num_stages=3 lets the
            # loads of upcoming slices overlap the GEMM on the current one.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # B is laid out (N, K), so the tile product uses B transposed.
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_fp8


def test_gemm_fp8(M, N, K, dtype):
    """Run the FP8 GEMM kernel on random CUDA inputs and compare it with PyTorch.

    Args:
        M, N, K: Problem sizes; ``A`` is (M, K) and ``B`` is (N, K).
        dtype: TileLang FP8 element type used for the inputs and the output.

    Raises:
        AssertionError: If ``calc_diff`` between the kernel output and the
            reference is not below 1e-3.
    """
    torch_dtype = T.dtype(dtype).as_torch()

    kernel = matmul(M, N, K, 128, 128, 64, dtype)

    # Sample in FP16, then cast to the FP8 dtype; the reference below reuses
    # these same quantized tensors, so both sides see identical inputs.
    a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
    b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)

    c = kernel(a, b)

    # Reference: upcast to FP16, multiply by B^T (B is (N, K)), then cast back
    # to the kernel's output dtype so both results share the same precision.
    ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype)

    print(c)
    print(ref_c)

    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


def main():
    """Check the FP8 GEMM kernel on a 1024x1024x1024 problem for E4M3 and E5M2."""
    for fp8_dtype in (T.float8_e4m3fn, T.float8_e5m2):
        test_gemm_fp8(1024, 1024, 1024, fp8_dtype)


if __name__ == "__main__":
    # Run the example only when executed as a script, not on import.
    main()