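"""Correctness tests for the TileLang MFMA GEMM kernel with a preshuffled B operand.

The kernel is generated through MatrixCorePreshuffleIntrinEmitter and exercised on
ROCm with int8 and float8_e4m3fnuz inputs, covering both B orientations, k-packing,
and an optional global-to-local (G2L) load path for B.
"""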
import pytest
import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
    b_g2l_load=False,
):
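    # Flag semantics, as exercised by the tests below:
    #   b_preshuffle: B is already tiled host-side into (n-block, k-block,
    #       micro_n, micro_k) blocks (see shuffle_weight), so the emitter can
    #       consume it without an in-kernel shuffle.
    #   b_g2l_load: load B fragments directly from global memory into registers,
    #       skipping the shared-memory staging copy.
    #   k_pack: number of micro_size_k slices packed into one fragment load
    #       (pack_size_k = k_pack * micro_size_k).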
    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32

    # For a preshuffled B, use a (1, 4) warp layout: 1 warp along M, 4 along N.
    if b_preshuffle:
        block_row_warps = 1
        block_col_warps = 4
        warp_row_tiles = 64
        warp_col_tiles = 16

    chunk = 256 * k_pack

    pack_size_k = micro_size_k * k_pack

    shared_scope = "shared"

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (
            (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shape = (N, K) if b_transposed else (K, N)

    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (
            (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MFMA emitter that auto-generates MatrixCore (MFMA) intrinsic code
    mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
    )

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                }
            )

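            # K is split into num_ko stages of block_K elements; each stage is
            # consumed in num_ki fragment loads of k_pack * micro_size_k elements.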
            num_ko = K // block_K
            num_ki = block_K // (k_pack * micro_size_k)

            # Swizzle the block scheduling order to improve L2 cache reuse
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined(num_ko, num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if not b_g2l_load:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k,
                                                       micro_size_y, pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
                                                       ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y,
                                                       pack_size_k, micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                       bx * block_N // micro_size_y + j, kk, jj]

                for ki in T.serial(0, num_ki):
                    # Load A S2L
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    if b_g2l_load:
                        # Load B G2L
                        mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
                    else:
                        # Load B S2L
                        mfma_emitter.ldmatrix_b(
                            B_local,
                            B_shared,
                            ki,
                        )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local)

            # Perform STMatrix
            mfma_emitter.stmatrix(
                C_local,
                C,
                pid_m=by,
                pid_n=bx,
            )

    return main


def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
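    """Rearrange B into the block-tiled layout expected when b_preshuffle=True.

    The matrix is viewed as a grid of (BN, BK) = (layout[0], layout[1] * k_pack)
    tiles and the tile grid is moved to the outermost dimensions. For example, a
    transposed (256, 512) B with the default layout and k_pack=1 becomes
    (256 // 16, 512 // 32, 16, 32) = (16, 16, 16, 32).
    """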
    IN, IK = layout
    BK = IK * k_pack
    BN = IN

    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()


def assert_tl_matmul_correctness(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype="float32",
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
    b_g2l_load=False,
):
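    """Compile the kernel, run it on random inputs (preshuffling B host-side when
    b_preshuffle is set), and check the result against a float32 torch.matmul
    reference with rtol = atol = 1e-2."""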
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed,
                       b_transposed, k_pack, b_preshuffle, b_g2l_load)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated kernel source (HIP for the ROCm targets tested here)
    assert src_code is not None
    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    elif in_dtype == "float8_e4m3fnuz":
        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))

    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    B_preshuffle = B
    if b_preshuffle:
        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
        kernel(A, B_preshuffle, C)
    else:
        kernel(A, B, C)

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    # Reference result in float32, honoring the transpose flags
    if a_transposed and b_transposed:
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))

    print(C)
    print(ref_c)

    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


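# Every case uses a preshuffled B (b_preshuffle=True) with the shared-memory B path
# (b_g2l_load=False), covering int8 and float8_e4m3fnuz, both B orientations, and
# k_pack in {1, 2}.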
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load",
    [
        (256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False),
    ],
)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed,
    b_transposed,
    k_pack,
    b_preshuffle,
    b_g2l_load,
):
    assert_tl_matmul_correctness(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
        b_g2l_load=b_g2l_load,
    )


if __name__ == "__main__":
    tilelang.testing.main()