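"""FP8 GEMM autotuning example for AMD GPUs (example_tilelang_gemm_amd.py).

Inputs use torch.float8_e4m3fnuz (the fnuz FP8 variant used on AMD GPUs) and
results are accumulated in fp32. Two kernel variants are tuned: "ss" stages
both operand tiles in shared memory, while "rs" keeps the A tile in register
fragments.
"""
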
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools


def ref_program(A, B):
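    """Reference GEMM: upcast FP8 inputs to fp16, compute A @ B^T, return fp32."""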
    return (A.half() @ B.half().T).to(dtype=torch.float32)


def manual_check_prog(C, C_ref):
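    """Relaxed correctness check used by the autotuner for FP8 results."""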
    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)


def supply_prog(args):
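    """Supply small-magnitude random FP8 (e4m3fnuz) inputs matching the kernel's parameter shapes."""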
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    return [a, b]


def get_configs():
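    """Enumerate the autotuning search space over tile sizes, pipelining,
    thread count, k_pack, and kernel variant."""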
    block_Ms = [32, 64, 128]
    block_Ns = [32, 64, 128]
    block_Ks = [64, 128]
    num_stages = [0]
    num_threads = [256]
    k_packs = [1, 2]
    gemm_types = ["ss", "rs"]

    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(
        block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
    ):
        valid_configs.append(
            {
                "block_M": m,
                "block_N": n,
                "block_K": k,
                "num_stages": stages,
                "num_threads": t,
                "k_pack": kp,
                "gemm_type": gemm_type,
            }
        )
    return valid_configs


@tilelang.autotune(
    configs=get_configs(),
    cache_input_tensors=True,
    ref_prog=ref_program,
    manual_check_prog=manual_check_prog,
    supply_prog=supply_prog,
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
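    # gemm_type selects the operand staging strategy: "ss" stages both A and B
    # tiles in shared memory, while "rs" keeps the A tile in register fragments.
    # k_pack packs multiple K elements into a single AMD MFMA instruction.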
    dtype = T.float8_e4m3fnuz
    accum_dtype = T.float32

    @T.prim_func
    def gemm_fp8_rs(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
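            # One thread block per output tile: bx indexes tiles along N, by along M.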
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
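            # Accumulate over K tiles; B is stored as (N, K), hence transpose_B=True.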
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
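            # Same tiling as gemm_fp8_rs, but the A tile is also staged in shared memory.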
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")


def test_gemm_fp8(M, N, K):
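    # The first call autotunes over get_configs() and returns the best compiled
    # kernel; out_idx=[-1] tells TileLang to allocate and return C.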
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("passed~")


if __name__ == "__main__":
    test_gemm_fp8(512, 512, 512)