import argparse
import itertools
import logging
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_cutlass_metadata_layout

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

arch = nvcc.get_target_compute_version()

# Per-compute-capability metadata (E) packing: one metadata element of the given
# dtype covers this many elements along K (hence E has shape (M, K // e_factor)).
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs(M, N, K):
    """
    Generate a list of configuration dictionaries that will be used for tuning.
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    Parameters
    ----------
    with_roller : bool
        Whether to enable bitblas roller to deduce search spaces

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [0, 1, 2, 3]
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    policy = [T.GemmWarpPolicy.Square]
    _configs = list(
        itertools.product(
            block_M,
            block_N,
            block_K,
            num_stages,
            thread_num,
            policy,
            enable_rasterization,
        )
    )

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "policy": c[5],
            "enable_rasterization": c[6],  # keep param name for backward-compat
        }
        for c in _configs
    ]
    return configs
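# The product above yields 3 * 3 * 2 * 4 * 2 * 1 * 2 = 288 candidate configurations.
# For illustration only (arbitrary values, not a tuned result), one entry looks like:
#   {"block_M": 128, "block_N": 128, "block_K": 64, "num_stages": 2,
#    "thread_num": 256, "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True}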


def matmul_sp(M, N, K, in_dtype, accum_dtype):
    """
    Create an autotuned matrix multiplication kernel for matrices of shape:
      - A: (M, K)
95
      - B: (K, N)
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    (best_latency, best_config, ref_latency)
        best_latency : float
            The best latency found among the tuned configurations.
        best_config : dict
            The parameter configuration that yielded best_latency.
        ref_latency : float
            The baseline latency of the reference program (for computing speedup).
    """

    # Decorate the kernel with autotune & jit, specifying:
    #  - The list of tuning configurations to explore
    #  - Warmup and repetition counts for more stable measurements
    #  - out_idx=[2], which marks C (the third argument) as the kernel output
    @autotune(
        configs=get_configs(M, N, K),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[2],
    )
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        policy=None,
        enable_rasterization=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        k_pack : int
            K dimension packing factor to improve memory coalescing.

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # in_dtype is typically half precision to reduce memory bandwidth, while
        # accumulation happens in accum_dtype for better numerical accuracy.
        # Look up the metadata packing factor and dtype for the current GPU architecture.
        e_factor, e_dtype = ARCH_INFO[arch]

        @T.prim_func
        def main(
            A_sparse: T.Tensor((M, K // 2), in_dtype),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), in_dtype),
            C: T.Tensor((M, N), accum_dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            #     y-dimension to block index in M.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                # Allocate shared memory for the compressed A sub-block of shape (block_M, block_K // 2)
                A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
                # Allocate shared memory for the B sub-block of shape (block_K, block_N)
                B_shared = T.alloc_shared((block_K, block_N), in_dtype)
                # Allocate shared memory for the E (metadata) sub-block of shape (block_M, block_K // e_factor)
                E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                # Allocate shared memory for the C sub-block of shape (block_M, block_N)
                C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

                # Clear out the accumulation buffer
                T.clear(C_local)
                T.disable_warp_group_reg_alloc()

                # Optionally swizzle the block scheduling order (rasterization) to improve L2 reuse
                T.use_swizzle(panel_size=10, enable=enable_rasterization)
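                # Lay out the sparsity metadata in the CUTLASS metadata format expected by the
                # sparse GEMM, for both the global E tensor and its shared-memory staging buffer.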
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
                    }
                )
                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                    # Load a sub-block of E from global memory into E_shared
                    T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    # Perform a partial sparse matrix multiplication:
                    #   C_local += A_block @ B_shared, where A_block is decoded from
                    #   the compressed values (A_shared) and metadata (E_shared)
                    T.gemm_sp_v2(
                        A_shared,
                        E_shared,
                        B_shared,
                        C_local,
                        transpose_B=False,
                        policy=policy,
                    )
                # Write back the results: stage C_local in shared memory, then copy to global C
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

    return kernel()


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--disable_cache", action="store_true")
    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
    parser.add_argument(
        "--bench_torch_sparse",
        type=str,
        choices=["cutlass", "cusparselt"],
        default=None,
        help="Benchmark against the torch sparse semi-structured implementation; note that currently only sm80 is supported",
    )
    args = parser.parse_args()

    if args.disable_cache:
        tilelang.disable_cache()

    M, N, K = args.m, args.n, args.k

    # Compute total floating-point operations to measure throughput
    total_flops = 2 * M * N * K

    # matmul_sp(...) returns the autotuner result; read off the best latency and config
    best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
    best_latency = best_result.latency
    best_config = best_result.config
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(K, N, dtype=torch.float16, device="cuda")
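    # do_bench reports latency in milliseconds, which the TFLOPS formulas below assume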
    ref_latency = do_bench(lambda: A @ B)

    if args.bench_torch_sparse is not None:
        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

        if args.bench_torch_sparse == "cutlass":
            SparseSemiStructuredTensor._FORCE_CUTLASS = True
        A_sp = to_sparse_semi_structured(A, transposed=False)
        torch_sparse_latency = do_bench(lambda: A_sp @ B)

    # Print out the benchmark results
    print(f"Best latency (s): {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
    print(f"Best config: {best_config}")

    if args.bench_torch_sparse is not None:
        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")

    print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")