# example_tilelang_gemm_amd.py
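# Autotuned FP8 GEMM example for AMD GPUs (float8_e4m3fnuz): computes
# C = A @ B^T with fp32 accumulation and checks the result against a
# PyTorch fp16 reference.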
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools


def ref_program(A, B):
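    # Reference GEMM: upcast the fp8 inputs to fp16 for the matmul and return fp32
    # to match the kernel's accumulator dtype.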
    return (A.half() @ B.half().T).to(dtype=torch.float32)


def manual_check_prog(C, C_ref):
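    # Custom correctness check used by the autotuner: compare the first output tensor
    # against the reference with tolerances loose enough for fp8 inputs.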
    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)


def supply_prog(args):
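    # Custom input supplier for the autotuner: small-magnitude fp8 (e4m3fnuz) tensors.
    # PyTorch's ROCm build still uses the 'cuda' device string on AMD GPUs.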
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    return [a, b]


def get_configs():
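    # Autotuning search space: tile sizes, pipeline stages, thread count,
    # k_pack width, and which kernel variant ("ss" or "rs") to instantiate.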
    block_Ms = [32, 64, 128]
    block_Ns = [32, 64, 128]
    block_Ks = [64, 128]
    num_stages = [0]
    num_threads = [256]
    k_packs = [1, 2]
    gemm_types = ["ss", "rs"]

    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
                                                               num_stages, num_threads, k_packs,
                                                               gemm_types):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "block_K": k,
            "num_stages": stages,
            "num_threads": t,
            "k_pack": kp,
            "gemm_type": gemm_type,
        })
    return valid_configs


@tilelang.autotune(
    configs=get_configs(),
    cache_input_tensors=True,
    ref_prog=ref_program,
    manual_check_prog=manual_check_prog,
    supply_prog=supply_prog)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
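    # Autotuned, JIT-compiled FP8 GEMM (C = A @ B^T). `gemm_type` selects between the
    # shared/shared ("ss") and register/shared ("rs") kernel bodies defined below.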
    dtype = "float8_e4m3fnuz"
    accum_dtype = "float"

    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
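        # "rs" variant: the A tile lives in a per-thread register fragment,
        # while the B tile is staged in shared memory.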
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
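                # Walk K in block_K chunks; T.Pipelined lets the compiler software-pipeline
                # the tile copies with the gemm according to num_stages.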
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_local,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
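        # "ss" variant: both the A and B tiles are staged in shared memory.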
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")


def test_gemm_fp8(M, N, K):
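    # Autotune over get_configs(), then validate the best kernel against the reference.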
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
         0.01).to(dtype=torch.float8_e4m3fnuz)
    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("passed")


if __name__ == "__main__":
    test_gemm_fp8(512, 512, 512)