"docs/source/reference.rst" did not exist on "715b1899ed1a3caa29426a1c2e68a236bec87088"
example_gemm.py 6.83 KB
Newer Older
1
2
3
4
import argparse
import torch
import itertools
import tilelang as tl
5
import tilelang.language as T
yyttt6's avatar
yyttt6 committed
6
from tilelang.autotuner import AutoTuner
7
8
9
10
11
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization


def ref_program(A, B):
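    """PyTorch reference: C = A @ B^T, matching the kernel's transpose_B layout."""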
    return A @ B.T


def get_configs(M, N, K, with_roller=False):
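    """Build the autotuning search space for the GEMM kernel.

    With `with_roller=True`, candidate configs come from the carver's
    MatmulTemplate hardware-aware hints (top-k); otherwise a brute-force
    Cartesian product over tile sizes, pipeline stages, thread counts, and
    rasterization is used.
    """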
    if with_roller:
        arch = CUDA("cuda")
        topk = 10
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            # block_rows, block_cols represents warp partitioning
            block_rows, block_cols = block_m // warp_m, block_n // warp_n
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = hint.pipeline_stage
            config["thread_num"] = block_rows * block_cols * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64, 128, 256]
        block_N = [64, 128, 256]
        block_K = [32, 64]
        num_stages = [0, 1, 2, 3]
        thread_num = [128, 256]
        enable_rasterization = [True, False]
        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            } for c in _configs
        ]
    return configs


def get_best_config(M, N, K, with_roller=False):
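    """Autotune the GEMM kernel over the config search space and return the
    tuning result (best kernel plus its measured latency).
    """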

    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A: T.Tensor((M, K), dtype),
                B: T.Tensor((N, K), dtype),
                C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
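                # Each block (bx, by) computes one block_M x block_N tile of C.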
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                C_shared = T.alloc_shared((block_M, block_N), dtype)
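                # Optional block-order swizzling (rasterization) to improve L2 reuse across blocks.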
                T.use_swizzle(panel_size=10, enable=enable_rasteration)
                T.clear(C_local)
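                # Pipelined reduction over K tiles; num_stages sets the software-pipeline
                # depth (multi-buffering of the shared-memory copies below).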
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

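    # Compile and profile every candidate config; results are validated against
    # ref_program (skip_check=False) and timed with 3 warmup runs and 20 repetitions.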
    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
            out_idx=[-1],
            supply_type=tl.TensorSupplyType.Integer,
            ref_prog=ref_program,
            skip_check=False,
            target="auto",
        )
    return autotuner.run(warmup=3, rep=20)


def matmul(M,
           N,
           K,
           block_M,
           block_N,
           block_K,
           num_stages,
           thread_num,
           enable_rasteration,
           dtype="float16",
           accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
    parser.add_argument(
        "--use_autotune",
        action="store_true",
        default=True,
        help="Whether to use autotune for matmul configs")
    parser.add_argument(
        "--with_roller",
        action="store_true",
        default=True,
        help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(N, K).cuda().half()
    configs = []
    use_autotune = args.use_autotune
    with_roller = args.with_roller
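    # Note: with action="store_true" and default=True, both flags are effectively
    # always on; change the defaults above to exercise the fixed-config path.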
    if use_autotune:
        result = get_best_config(M, N, K, with_roller)
        print(f"best latency {result.latency}")
        kernel = result.kernel
    else:
        kernel = tl.compile(matmul(M, N, K, 128, 128, 32, 3, 128, True), out_idx=-1)

    out_c = kernel(a, b)
    ref_c = ref_program(a, b)
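    # Check the compiled kernel against the PyTorch reference within fp16-friendly tolerances.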
    torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)