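"""Benchmark an autotuned 2:4 sparse matrix multiplication kernel written in tilelang,
comparing it against a dense torch matmul and, optionally, torch's semi-structured
sparse implementation."""
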
import argparse
import itertools
import logging
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

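# Query the compute capability of the current GPU (e.g. "8.0" for A100-class parts)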
arch = nvcc.get_target_compute_version()

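# Map compute capability to the (packing factor, dtype) of the 2:4 sparsity metadata
# tensor E. For example, on sm80 a (M, K) sparse operand is stored as values of shape
# (M, K // 2) plus metadata of shape (M, K // 16) with dtype int16.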
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs(M, N, K):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    M, N, K : int
        The matrix dimensions of the GEMM being tuned.

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [0, 1, 2, 3]
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    policy = [T.GemmWarpPolicy.Square]
    _configs = list(
        itertools.product(
            block_M,
            block_N,
            block_K,
            num_stages,
            thread_num,
            policy,
            enable_rasterization,
        ))

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "policy": c[5],
            "enable_rasterization": c[6],  # keep param name for backward-compat
        } for c in _configs
    ]
    return configs


def matmul_sp(M, N, K, accum_dtype):
    """
    Create an autotuned sparse (2:4) matrix multiplication kernel for matrices of shape:
      - A: (M, K), 2:4 structured sparse, passed in compressed form with its metadata
      - B: (K, N)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.
    accum_dtype : str
        Accumulation datatype, either "float" or "float16".

    Returns
    -------
    result
        The autotuner result for the best configuration found; its
        ``latency`` and ``config`` attributes hold the best latency and
        the parameter configuration that achieved it.
    """

    # Decorate the kernel with autotune & jit, specifying:
    #  - The tuning config list
    #  - Warmup and repetition counts for better measurement
    #  - out_idx=[2], marking argument 2 (C) as the kernel output

    @autotune(
        configs=get_configs(M, N, K),
        warmup=3,
        rep=20,
    )
    @jit(out_idx=[2],)
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        policy=None,
        enable_rasterization=None,
    ):
        """
        The actual kernel to compute C = A @ B, where A is 2:4 structured sparse.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        policy : T.GemmWarpPolicy
            Warp partitioning policy for the tensor-core GEMM.
        enable_rasterization : bool
            Whether to enable block rasterization (swizzling).

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Use half-precision for input data to reduce memory bandwidth;
        # accumulation happens in accum_dtype ("float" by default) for better
        # numerical accuracy
        dtype = "float16"
        e_factor, e_dtype = ARCH_INFO[arch]

        @T.prim_func
        def main(
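                # A_sparse holds only the kept values of the 2:4-sparse A (K // 2 per row);
                # E carries the corresponding sparsity metadata.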
                A_sparse: T.Tensor((M, K // 2), dtype),
                E: T.Tensor((M, K // e_factor), e_dtype),
                B: T.Tensor((K, N), dtype),
                C: T.Tensor((M, N), accum_dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            #     y-dimension to block index in M.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

                # Allocate shared memory for the compressed A sub-block of shape (block_M, block_K // 2)
                A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
                # Allocate shared memory for the B sub-block of shape (block_K, block_N)
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                # Allocate shared memory for the E (metadata) sub-block of shape (block_M, block_K // e_factor)
                E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                # Allocate shared memory for the C sub-block of shape (block_M, block_N)
                C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

                # Clear out the accumulation buffer
                T.clear(C_local)
                T.disable_warp_group_reg_alloc()

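                # Rasterize (swizzle) the order in which blocks are scheduled to improve L2 locality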
                T.use_swizzle(panel_size=10, enable=enable_rasterization)
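                # Lay out the metadata tensors in the format expected by the CUTLASS
                # sparse tensor-core backend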
                T.annotate_layout({
                    E:
                        make_metadata_layout(
                            E, mma_dtype="float16", backend="cutlass", block_k=block_K),
                    E_shared:
                        make_metadata_layout(
                            E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
                })
                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of the compressed A from global memory into A_shared
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                    # Load a sub-block of E from global memory into E_shared
                    T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    # Perform a partial sparse matrix multiplication, using the metadata
                    # in E_shared to expand the compressed A_shared:
                    #   C_local += A_shared @ B_shared
                    T.gemm_sp(
                        A_shared,
                        E_shared,
                        B_shared,
                        C_local,
                        transpose_B=False,
                        policy=policy,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

    return kernel()


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--disable_cache", action="store_true")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument(
        "--bench_torch_sparse",
        type=str,
        choices=['cutlass', 'cusparselt'],
        default=None,
        help="Benchmark against torch's semi-structured sparse matmul; note that at present only sm80 is supported"
    )
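
    # Example invocation (sizes below are arbitrary):
    #   python benchmark_matmul_sp.py --m 4096 --n 4096 --k 4096 --accum_dtype float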
    args = parser.parse_args()

    if args.disable_cache:
        tilelang.disable_cache()

    M, N, K = args.m, args.n, args.k

    # Compute total floating-point operations to measure throughput
    total_flops = 2 * M * N * K

    # matmul_sp(...) runs the autotuner and returns the result for the best
    # configuration found (exposing `latency` and `config` attributes)
    best_result = matmul_sp(M, N, K, args.accum_dtype)
    best_latency = best_result.latency
    best_config = best_result.config
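    # Benchmark a dense half-precision torch matmul as the reference baseline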
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(K, N, dtype=torch.float16, device="cuda")
    ref_latency = do_bench(lambda: A @ B)

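    # Optionally benchmark torch's 2:4 semi-structured sparse matmul for comparison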
    if args.bench_torch_sparse is not None:
        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
        if args.bench_torch_sparse == 'cutlass':
            SparseSemiStructuredTensor._FORCE_CUTLASS = True
        A_sp = to_sparse_semi_structured(A, transposed=False)
        torch_sparse_latency = do_bench(lambda: A_sp @ B)

    # Print out the benchmark results
    print(f"Best latency (ms): {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
    print(f"Best config: {best_config}")

    if args.bench_torch_sparse is not None:
        print(
            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
        )

    print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")