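"""Benchmark an autotuned 2:4 structured-sparse matrix multiplication kernel
written in TileLang against dense torch matmul and, optionally, torch's
semi-structured sparse matmul backends."""
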
import argparse
import itertools
import logging
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_cutlass_metadata_layout

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
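# Per-architecture metadata packing for 2:4 sparsity: each element of the
# metadata tensor E (of type e_dtype) encodes the sparsity indices for
# e_factor elements along K, which is why E has shape (M, K // e_factor) below.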


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T

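# Illustrative helper (hypothetical, not used by the benchmark): one way to
# prune a dense matrix to the 2:4 structured-sparsity pattern with plain torch
# ops, keeping the two largest-magnitude values in every group of four along K.
# The TileLang kernel below instead consumes the compressed values (A_sparse)
# together with the CUTLASS metadata tensor (E).
def prune_2_4(A: torch.Tensor) -> torch.Tensor:
    M, K = A.shape
    assert K % 4 == 0, "K must be a multiple of 4 for 2:4 sparsity"
    groups = A.reshape(M, K // 4, 4)
    # Indices of the two largest-magnitude entries in each group of four.
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(M, K)
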

def get_configs(M, N, K):
    """
    Generate the list of configuration dictionaries explored during autotuning.

    Parameters
    ----------
    M, N, K : int
        The matrix dimensions (currently unused; the search space below is
        independent of the problem shape).

    Returns
    -------
    list of dict
        Each configuration dict includes block sizes, pipeline stages,
        thread numbers, warp policy, and rasterization settings to explore
        during autotuning.
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [0, 1, 2, 3]
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    policy = [T.GemmWarpPolicy.Square]
    _configs = list(
        itertools.product(
            block_M,
            block_N,
            block_K,
            num_stages,
            thread_num,
            policy,
            enable_rasterization,
        ))
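    # The Cartesian product above yields 3 * 3 * 2 * 4 * 2 * 1 * 2 = 288
    # candidate configurations for the autotuner to profile.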

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "policy": c[5],
            "enable_rasterization": c[6],  # keep param name for backward-compat
        } for c in _configs
    ]
    return configs


def matmul_sp(M, N, K, in_dtype, accum_dtype):
    """
    Create an autotuned 2:4 structured-sparse matrix multiplication kernel for
    matrices of logical shape:
      - A: (M, K)  (2:4 sparse; passed as compressed values plus metadata)
      - B: (K, N)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.
    in_dtype : str
        Input datatype of A and B (e.g. "float16").
    accum_dtype : str
        Accumulation datatype of C (e.g. "float").

    Returns
    -------
    result
        The autotuner result object; `result.latency` is the best latency found
        among the tuned configurations and `result.config` is the parameter
        configuration that produced it (see the usage in `__main__`).
    """

    # Decorate the kernel with autotune & jit, specifying:
    #  - The tuning config list
    #  - Warmup and repetition counts for better measurement
    #  - out_idx so the JIT wrapper allocates and returns the output tensor C

    @autotune(
        configs=get_configs(M, N, K),
        warmup=3,
        rep=20,
    )
    @jit(out_idx=[3])
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        policy=None,
        enable_rasterization=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        policy : T.GemmWarpPolicy
            Warp partitioning policy for the tensor-core GEMM.
        enable_rasterization : bool
            Whether to enable swizzle-based block rasterization for better L2 locality.

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Inputs use in_dtype (half precision here) to reduce memory bandwidth;
        # accumulation uses accum_dtype for better numerical accuracy.
        e_factor, e_dtype = ARCH_INFO[arch]

        @T.prim_func
        def main(
                A_sparse: T.Tensor((M, K // 2), in_dtype),
                E: T.Tensor((M, K // e_factor), e_dtype),
                B: T.Tensor((K, N), in_dtype),
                C: T.Tensor((M, N), accum_dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            #     y-dimension to block index in M.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

                # Allocate shared memory for A sub-block of shape (block_M, block_K)
                A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
                # Allocate shared memory for B sub-block of shape (block_K, block_N)
                B_shared = T.alloc_shared((block_K, block_N), in_dtype)
                # Allocate shared memory for E sub-block of shape (block_M, block_K // e_factor)
                E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)

                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                # Allocate shared memory for C sub-block of shape (block_M, block_N)
                C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

                # Clear out the accumulation buffer
                T.clear(C_local)
                T.disable_warp_group_reg_alloc()

                T.use_swizzle(panel_size=10, enable=enable_rasterization)
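                # Annotate E / E_shared with the metadata layout expected by the
                # CUTLASS sparse tensor-core path.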
                T.annotate_layout({
                    E:
                        make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
                    E_shared:
                        make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
                })
                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                    # Load a sub-block of E from global memory into E_shared
                    T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    # Perform a partial sparse matrix multiplication:
                    #   C_local += decompress(A_shared, E_shared) @ B_shared
                    T.gemm_sp_v2(
                        A_shared,
                        E_shared,
                        B_shared,
                        C_local,
                        transpose_B=False,
                        policy=policy,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

    return kernel()


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--disable_cache", action="store_true")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument(
        "--bench_torch_sparse",
        type=str,
        choices=['cutlass', 'cusparselt'],
        default=None,
        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
    )
    args = parser.parse_args()

    if args.disable_cache:
        tilelang.disable_cache()

    M, N, K = args.m, args.n, args.k

    # Compute total floating-point operations to measure throughput
    total_flops = 2 * M * N * K
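    # For the default 16384 x 16384 x 16384 problem this is
    # 2 * 16384**3, i.e. about 8.8e12 floating-point operations per matmul.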

    # matmul_sp(...) returns the autotuner result; extract the best latency and config
    best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
    best_latency = best_result.latency
    best_config = best_result.config
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(K, N, dtype=torch.float16, device="cuda")
    ref_latency = do_bench(lambda: A @ B)

    if args.bench_torch_sparse is not None:
        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
        if args.bench_torch_sparse == 'cutlass':
            SparseSemiStructuredTensor._FORCE_CUTLASS = True
        A_sp = to_sparse_semi_structured(A, transposed=False)
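        # A is dense random data, so A_sp is not a meaningful 2:4 pruning of A;
        # this comparison measures throughput only.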
        torch_sparse_latency = do_bench(lambda: A_sp @ B)

    # Print out the benchmark results
    print(f"Best latency (s): {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
    print(f"Best config: {best_config}")

    if args.bench_torch_sparse is not None:
        print(
            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
        )

    print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")