import pytest
import torch
import tilelang
import tilelang.testing

from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.tensor import torch_assert_close, map_torch_type
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter

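# Disable TF32 so the dense torch.matmul reference below runs in strict fp32.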
torch.backends.cuda.matmul.allow_tf32 = False
# torch.manual_seed(42)  # only enable when debugging


def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
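    """Create dense CUDA tensors A and B for the requested dtype and transpose flags.

    A is drawn with a semi-structured (2:4) sparsity pattern so it can be compressed
    later; B is fully dense. Integer dtypes are sampled from narrow value ranges.
    """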
    is_8bit = "8" in in_dtype
    is_unsigned = "uint" in in_dtype
    is_int = "int" in in_dtype
    if is_int:
        if is_8bit:
            low, high = (0, 4) if is_unsigned else (-2, 2)
        else:
            low, high = (0, 128) if is_unsigned else (-64, 64)
        A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
        B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
    else:
        A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype))
        B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
    return A, B


def matmul_sp_sm90(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
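    """Build a 2:4 sparse GEMM TileLang kernel targeting SM90 (Hopper).

    A is consumed in compressed form (K halved) along with a uint8 metadata tensor E
    laid out via make_cutlass_metadata_layout for T.gemm_sp.
    """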
    E_factor = 4 if in_dtype == "float32" else 8
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    import tilelang.language as T

    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), "uint8"),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
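            # Annotate the global and shared metadata tensors with the CUTLASS
            # metadata layout expected by T.gemm_sp on this architecture.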
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
                }
            )
            T.disable_warp_group_reg_alloc()
            T.clear(C_frag)
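            # Software-pipelined main loop: stage E/A/B tiles through shared memory,
            # then issue the sparse tensor-core GEMM on each tile.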
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main


def matmul_sp_sm80(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
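    """Build a 2:4 sparse GEMM TileLang kernel targeting SM80-SM89 (Ampere/Ada).

    The metadata tensor E uses int32 elements for 8-bit inputs and int16 otherwise,
    with E_factor looked up from the sparse tensor-core intrinsic emitter.
    """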
    is_8_bit = "8" in in_dtype
    metadata_dtype = "int32" if is_8_bit else "int16"
    E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    import tilelang.language as T

    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
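            # Same CUTLASS metadata layout annotation as the SM90 kernel, but for
            # the SM80 mma.sp path (arch="8.0").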
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
                }
            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main


def normalize(tensor, max_range=100.0):
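    """Scale `tensor` so its maximum absolute value equals `max_range`.

    `max_range` is capped at 448.0, the largest finite float8_e4m3 value.
    """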
    assert max_range <= 448.0
    max_v = tensor.abs().max().clamp(1e-4)
    scaler = max_range / max_v
    return tensor * scaler


def calc_diff(x, y):
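    """Global relative difference: 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)).

    Returns 0 for identical tensors and grows as they diverge.
    """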
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def run_gemm_sp(
    kernel,
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    block_K,
    trans_A,
    trans_B,
):
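    """Compile `kernel`, run it on compressed sparse inputs, and check the result
    against a dense torch.matmul reference."""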
    kernel = tilelang.compile(
        kernel,
        out_idx=[-1],
    )
    A, B = generate_dense_input(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
    )
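    # Split the semi-sparse A into compressed values (K halved) and metadata E.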
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)

    C_sp = kernel(A_sparse, E, B)

    def _matmul(A, B):
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if "float8" in in_dtype or "int8" in in_dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        return torch.matmul(A, B)

    C = _matmul(A, B)

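    # fp8 results are checked with a global relative-difference metric; all other
    # dtypes use elementwise rtol/atol tolerances.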
    if "float8" in in_dtype:
        diff = calc_diff(C_sp, C)
        assert diff < 1e-3, f"{diff=}"
    else:
        torch_assert_close(
            C_sp.to(torch.float32),
            C.to(torch.float32),
            rtol=1e-3,
            atol=1e-3,
            base_name="tilelang_sp",
            ref_name="ref_dense",
        )
    print("pass")


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def run_gemm_sp_sm90(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A,
    trans_B,
):
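    """Instantiate the SM90 sparse GEMM kernel and validate it against dense matmul."""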
    kernel = matmul_sp_sm90(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
        trans_A,
        trans_B,
    )
    run_gemm_sp(
        kernel,
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        block_K,
        trans_A,
        trans_B,
    )


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def run_gemm_sp_sm80(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A,
    trans_B,
):
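    """Instantiate the SM80 sparse GEMM kernel and validate it against dense matmul."""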
    kernel = matmul_sp_sm80(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
        trans_A,
        trans_B,
    )
    run_gemm_sp(
        kernel,
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        block_K,
        trans_A,
        trans_B,
    )


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True),
        (512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
    ],
)
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False),
        (512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
    ],
)
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)


if __name__ == "__main__":
    tilelang.testing.main()