# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench

import torch

arch = nvcc.get_target_compute_version()

DEFAULT_CONFIG = {  # best configs from the autotune script, keyed by GPU and accumulation dtype
    "4090": {
        "float": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 64,
            "num_stages": 1,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        "float16": {
            "block_M": 256,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
    "h20": {
        "float": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        "float16": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
}

# Maps compute capability to (K elements encoded per metadata element, metadata dtype).
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
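    # block_M/N/K, num_stages, and thread_num are tuning knobs taken from DEFAULT_CONFIG;
    # `policy` controls how warps are partitioned over the output tile.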
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
        A_sparse: T.Tensor((M, K // 2), "float16"),  # only the kept values of each 2:4 group along K
        E: T.Tensor((M, K // e_factor), e_dtype),  # packed sparsity-selection metadata
        B: T.Tensor((K, N), "float16"),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
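            # Shared-memory tiles: A_shared holds only the kept half of each 2:4 group
            # along K, and E_shared holds the packed selection metadata for that tile.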
            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
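            # Lay out the sparsity metadata (in global and shared memory) in the
            # interleaved format expected by the CUTLASS sparse tensor-core path.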
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
                }
            )
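            # Main loop: each pipelined iteration stages a half-width A tile, its
            # metadata, and a full B tile into shared memory, then issues a sparse GEMM.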
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

            # Epilogue: stage the accumulator through shared memory before writing C back to global memory.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16


def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
    args = parser.parse_args()
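    # Build the JIT-compiled kernel; out_idx=[-1] makes TileLang allocate and
    # return the last argument (C) as the output.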
    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])

    # Dense fp16 input that already follows the 2:4 semi-structured sparsity pattern, plus a dense B.
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)

    # Compress the semi-sparse matrix into packed non-zero values plus the metadata tensor expected by the kernel.
    a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
    c = kernel(a_sparse, e, b)

    ref_c = a @ b  # dense reference computed from the uncompressed matrix

    assert not c.isnan().any(), "Kernel output contains NaNs, please report an issue"
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
    print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    # do_bench reports latency in milliseconds; TFLOPS are computed against the dense FLOP count 2*M*N*K.
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3} s")


if __name__ == "__main__":
    main()