# benchmark_matmul.py — autotuned TileLang matrix-multiplication benchmark.
import argparse
import itertools
import logging

5
import tilelang
6
import tilelang.language as T
7
8
from tilelang.autotuner import autotune
from tilelang import jit
9

10
11
12
13
14
# Module-level logger for this benchmark; its level is pinned to DEBUG.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)


15
def ref_program(A, B):
    """
    Reference matrix multiplication ``A @ B.T``, the baseline for the tuned kernel.

    Parameters
    ----------
    A : array-like
        Matrix with shape (M, K). Any object supporting ``@`` and ``.T``
        (e.g. a numpy array; presumably also torch tensors) works.
    B : array-like
        Matrix with shape (N, K); it is transposed before the product.

    Returns
    -------
    array-like
        The result of ``A @ B.T``, shape (M, N).
    """
    return A @ B.T


34
def _resolve_problem_args(args, kwargs):
    """Return ``(M, N, K, with_roller)`` from matmul's positional and keyword arguments."""
    names = ("M", "N", "K", "with_roller")
    kwargs = kwargs or {}
    # Leading names are bound positionally; any remaining ones must be keywords.
    values = dict(zip(names, args))
    for name in names[len(values):]:
        if name not in kwargs:
            raise TypeError(f"get_configs: missing required matmul argument {name!r}")
        values[name] = kwargs[name]
    return tuple(values[name] for name in names)


def _roller_configs(M, N, K, topk=10):
    """Build candidate configs from the carver roller's top-``topk`` hints for (M, N, K)."""
    # Imported lazily: these are only needed when the roller search space is requested.
    import torch
    from tilelang.carver.arch import CDNA, CUDA
    from tilelang.carver.roller.rasterization import NoRasterization
    from tilelang.carver.template import MatmulTemplate

    # torch.version.hip is None on non-ROCm builds of PyTorch.
    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")

    carve_template = MatmulTemplate(
        M=M,
        N=N,
        K=K,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float",
    ).with_arch(arch)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if carve_template.equivalent_function() is None:
        raise ValueError("Function is None")

    roller_hints = carve_template.recommend_hints(topk=topk)
    # An empty hint list would otherwise silently yield an empty search space.
    if not roller_hints:
        raise ValueError("No Roller Hints Found for TensorCore Scheduling")

    configs = []
    for hint in roller_hints:
        block_m, block_n = hint.block
        warp_m, warp_n = hint.warp
        # block_rows, block_cols represent the warp partitioning of the block
        block_rows, block_cols = block_m // warp_m, block_n // warp_n

        # Compare by type, not identity with the class object: a plan that is a
        # NoRasterization *instance* is never `is` the class itself.
        plan = hint.rasterization_plan
        no_raster = plan is NoRasterization or isinstance(plan, NoRasterization)

        configs.append(
            {
                "block_M": block_m,
                "block_N": block_n,
                "block_K": hint.rstep[0],
                "num_stages": hint.pipeline_stage,
                "thread_num": block_rows * block_cols * 32,
                "policy": T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols),
                "enable_rasteration": not no_raster,
            }
        )
    return configs


def _exhaustive_configs():
    """Return the Cartesian product of a fixed, hand-picked tuning space."""
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[32, 64],
        num_stages=[0, 1, 2, 3],
        thread_num=[128, 256],
        policy=[T.GemmWarpPolicy.Square],
        enable_rasteration=[True, False],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


def get_configs(args, kwargs):
    """
    Generate the list of configuration dictionaries explored by the autotuner.

    Passed to ``@autotune(configs=...)`` on ``matmul``; it receives the
    arguments ``matmul`` was called with.

    Parameters
    ----------
    args : tuple
        Positional arguments of the ``matmul`` call; leading entries are
        ``(M, N, K, with_roller)``.
    kwargs : dict
        Keyword arguments of the ``matmul`` call; any of ``M``, ``N``, ``K``,
        ``with_roller`` not given positionally are read from here.

    Returns
    -------
    list of dict
        Each dict assigns one candidate value to ``block_M``, ``block_N``,
        ``block_K``, ``num_stages``, ``thread_num``, ``policy`` and
        ``enable_rasteration``. When ``with_roller`` is true the candidates
        come from the carver roller's hints; otherwise from a fixed grid.

    Raises
    ------
    TypeError
        If ``M``, ``N``, ``K`` or ``with_roller`` was not supplied.
    ValueError
        If the roller yields no equivalent function or no hints.
    """
    M, N, K, with_roller = _resolve_problem_args(args, kwargs)

    if not with_roller:
        return _exhaustive_configs()

    configs = _roller_configs(M, N, K)
    for config in configs:
        print(config)
    return configs


109
110
111
112
113
@autotune(
    configs=get_configs,
    warmup=3,
    rep=20,
)
@jit(
    out_idx=[2],
)
def matmul(
    M,
    N,
    K,
    with_roller,
    block_M=None,
    block_N=None,
    block_K=None,
    num_stages=None,
    thread_num=None,
    policy=None,
    enable_rasteration=None,
):
    """
    Build a tiled TileLang kernel computing ``C = A @ B.T`` for shapes:
      - A: (M, K)
      - B: (N, K)
      - C: (M, N)

    The tiling parameters default to ``None`` and are supplied by the
    ``@autotune`` decorator from the candidates returned by ``get_configs``.

    Parameters
    ----------
    M, N, K : int
        Problem dimensions.
    with_roller : bool
        Not read by the kernel body; ``get_configs`` uses it to pick between
        the roller-derived and the fixed-grid search spaces.
    block_M, block_N, block_K : int
        Tile sizes along M, N and K.
    num_stages : int
        ``num_stages`` for the ``T.Pipelined`` loop over K tiles.
    thread_num : int
        Threads per block passed to ``T.Kernel``.
    policy : T.GemmWarpPolicy
        Warp policy forwarded to ``T.gemm``.
    enable_rasteration : bool
        Forwarded as ``enable`` to ``T.use_swizzle``.

    Returns
    -------
    The ``T.prim_func`` ``main``. Through the decorators, calling ``matmul``
    runs the autotuning; the ``__main__`` block reads ``latency``, ``config``
    and ``ref_latency`` from the object that call returns.
    (``out_idx=[2]`` presumably marks ``C``, the third tensor, as the output.)
    """

    # Use half-precision for input data to reduce memory bandwidth,
    # accumulate in "float" for better numerical accuracy
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        Block-level tiled matrix multiplication.

        - The (M, N) output is divided into tiles of shape (block_M, block_N),
          one per kernel block.
        - Each block stages tiles of A and B in shared memory.
        - Partial results accumulate in C_local, are staged through C_shared,
          and are then copied back to global C.
        """
        # Bind x-dimension to block index in N,
        #     y-dimension to block index in M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Shared-memory tile of A, shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Shared-memory tile of B, shape (block_N, block_K)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            # Fragment for accumulating the output tile in accum_dtype
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Shared-memory staging buffer for the output tile
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # Enable (or disable) block swizzling / rasterization
            T.use_swizzle(panel_size=10, enable=enable_rasteration)

            # Swizzled layout for C_shared (comment in original: "to utilize
            # swizzle tma layout")
            T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)})

            # Clear out the accumulation buffer
            T.clear(C_local)

            # Loop over K tiles, pipelined by num_stages
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load a tile of A from global memory into A_shared
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load a tile of B from global memory into B_shared
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Partial product: C_local += A_shared @ B_shared^T
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    policy=policy,
                )
            # Write the tile back to global C via the shared staging buffer
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
217
218
219
220
221


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--with_roller",
        action="store_true",
        help="Whether to enable BitBLAS roller for search space",
    )
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k
    with_roller = args.with_roller

    # Each of the M*N outputs takes K multiply-adds, i.e. 2*K flops.
    total_flops = 2 * M * N * K

    # Calling the autotuned matmul runs the search; the returned object exposes
    # the best latency/config and the reference latency as attributes.
    best_result = matmul(M, N, K, with_roller)
    best_latency = best_result.latency
    best_config = best_result.config
    ref_latency = best_result.ref_latency

    # The `* 1e-9` factor yields TFLOPS only when latency is in milliseconds:
    # flops / ms * 1e3 -> flops/s, then * 1e-12 -> TFLOPS.
    print(f"Best latency (ms): {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
    print(f"Best config: {best_config}")

    if ref_latency is not None:
        print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")