# test_tilelang_tilelibrary_gemm_sp.py
1
import pytest
2
3
4
import torch
import tilelang
import tilelang.testing
5
import tilelang.language as T
6

7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.tensor import torch_assert_close, map_torch_type
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter

# Disable TF32 for CUDA matmuls so the torch reference product used by the
# checks below is not computed with TF32 rounding.
torch.backends.cuda.matmul.allow_tf32 = False
# torch.manual_seed(42)  # only enable when debugging


def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
    """Create the random operand pair ``(A, B)`` for a sparse GEMM test case.

    ``A`` is produced by the ``*_semi_sparse`` helpers imported from
    ``tilelang.utils.sparse`` and ``B`` by the plain torch random generators.
    Both tensors are allocated on the CUDA device.

    Args:
        M, N, K: GEMM problem sizes.
        trans_A: Forwarded to the semi-sparse generator as ``transposed``.
        trans_B: When true ``B`` has shape ``(N, K)``, otherwise ``(K, N)``.
        in_dtype: Element type name. It is inspected with substring checks
            (``"int"``, ``"uint"``, ``"8"``) and converted to a torch dtype
            through ``map_torch_type``.

    Returns:
        Tuple ``(A, B)`` of tensors whose dtype is the torch equivalent of
        ``in_dtype``.
    """
    torch_dtype = map_torch_type(in_dtype)
    b_shape = (N, K) if trans_B else (K, N)

    if "int" in in_dtype:
        # Integer inputs are drawn from a narrow range; the 8-bit range is
        # narrower still.
        is_unsigned = "uint" in in_dtype
        if "8" in in_dtype:
            low, high = (0, 4) if is_unsigned else (-2, 2)
        else:
            low, high = (0, 128) if is_unsigned else (-64, 64)
        A = randint_semi_sparse(M, K, low=low, high=high, dtype=torch_dtype, device="cuda", transposed=trans_A)
        B = torch.randint(size=b_shape, low=low, high=high, dtype=torch_dtype, device="cuda")
    else:
        # Floating-point inputs are sampled in float32 first and cast to the
        # requested dtype afterwards.
        A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(torch_dtype)
        B = torch.randn(b_shape, device="cuda", dtype=torch.float32).to(torch_dtype)
    return A, B
31
32


33
def matmul_sp_sm90(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    """Build the TileLang sparse-GEMM kernel that uses the ``arch="9.0"`` metadata layout.

    Args:
        M, N, K: GEMM problem sizes.
        block_M, block_N, block_K: Tile sizes processed per kernel block.
        in_dtype: Element type of the compressed ``A`` and of ``B``.
        out_dtype: Element type of the result tensor ``C``.
        accum_dtype: Element type of the accumulator fragment.
        num_stages: ``num_stages`` passed to ``T.Pipelined`` for the K loop.
        threads: Thread count passed to ``T.Kernel``.
        trans_A: Whether the compressed ``A`` is stored as ``(K // 2, M)``.
        trans_B: Whether ``B`` is stored as ``(N, K)``.

    Returns:
        The ``T.prim_func`` taking ``(A_sparse, E, B, C)``.
    """
    # The metadata tensor E is uint8 with K // E_factor columns.
    E_factor = 4 if in_dtype == T.float32 else 8
    # The compressed A operand carries half of the K extent.
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), "uint8"),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Both the global and the staged metadata use the CUTLASS metadata layout.
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
                }
            )
            T.disable_warp_group_reg_alloc()
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the metadata tile, then the A and B tiles for this K step.
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
def matmul_sp_sm80(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    """Build the TileLang sparse-GEMM kernel that uses the ``arch="8.0"`` metadata layout.

    Args:
        M, N, K: GEMM problem sizes.
        block_M, block_N, block_K: Tile sizes processed per kernel block.
        in_dtype: Element type of the compressed ``A`` and of ``B``.
        out_dtype: Element type of the result tensor ``C``.
        accum_dtype: Element type of the accumulator fragment.
        num_stages: ``num_stages`` passed to ``T.Pipelined`` for the K loop.
        threads: Thread count passed to ``T.Kernel``.
        trans_A: Whether the compressed ``A`` is stored as ``(K // 2, M)``.
        trans_B: Whether ``B`` is stored as ``(N, K)``.

    Returns:
        The ``T.prim_func`` taking ``(A_sparse, E, B, C)``.
    """
    # 8-bit inputs use int32 metadata elements, everything else int16; the
    # matching K-per-metadata-element factor is looked up from the emitter.
    is_8bit = "8" in in_dtype
    metadata_dtype = T.int32 if is_8bit else T.int16
    E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
    # The compressed A operand carries half of the K extent.
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Both the global and the staged metadata use the CUTLASS metadata layout.
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
                }
            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the metadata tile, then the A and B tiles for this K step.
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main


148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def normalize(tensor, max_range=100.0):
    """Rescale ``tensor`` so that its largest absolute value equals ``max_range``.

    The peak magnitude is clamped to at least ``1e-4`` before dividing, so an
    all-zero tensor is returned unchanged instead of producing NaNs.

    Args:
        tensor: Tensor to rescale.
        max_range: Target peak magnitude; must not exceed 448.0
            (presumably the float8 e4m3 maximum -- confirm).

    Returns:
        A new tensor equal to ``tensor * (max_range / peak)``.
    """
    assert max_range <= 448.0
    peak = torch.clamp(tensor.abs().max(), min=1e-4)
    return tensor * (max_range / peak)


def calc_diff(x, y):
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The value is 0 when ``x`` and ``y`` are identical (and non-zero) and grows
    as they diverge; it is 2 when ``y == -x``.

    Args:
        x, y: Tensors of matching shape.

    Returns:
        A zero-dimensional float64 tensor holding the difference score.
    """
    x64 = x.double()
    y64 = y.double()
    cross = (x64 * y64).sum()
    energy = (x64 * x64 + y64 * y64).sum()
    return 1 - 2 * cross / energy


def run_gemm_sp(
    kernel,
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    block_K,
    trans_A,
    trans_B,
):
    """Compile ``kernel``, run it on random inputs and check it against a dense matmul.

    Args:
        kernel: TileLang ``prim_func`` to compile; its last argument is the output.
        M, N, K: GEMM problem sizes used to generate the inputs.
        in_dtype: Input element type name.
        out_dtype: Output element type; accepted for signature symmetry but
            not read by this function.
        block_K: Forwarded to ``compress`` as ``block_k``.
        trans_A: Whether ``A`` is generated/compressed in transposed form.
        trans_B: Whether ``B`` is generated in transposed form.

    Raises:
        AssertionError: If the sparse result deviates from the dense reference.
    """
    # out_idx=[-1] marks the last kernel argument (C) as the returned output.
    compiled = tilelang.compile(
        kernel,
        out_idx=[-1],
    )
    A, B = generate_dense_input(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
    )
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)

    C_sp = compiled(A_sparse, E, B)

    def _matmul(A, B):
        # Undo the storage transposes so the reference is always A @ B.
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        # The 8-bit types are upcast to float32 before the reference matmul.
        if "float8" in in_dtype or "int8" in in_dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        return torch.matmul(A, B)

    C = _matmul(A, B)

    if "float8" in in_dtype:
        # float8 results are compared with the similarity-based score only.
        diff = calc_diff(C_sp, C)
        assert diff < 1e-3, f"{diff=}"
    else:
        torch_assert_close(
            C_sp.to(torch.float32),
            C.to(torch.float32),
            rtol=1e-3,
            atol=1e-3,
            base_name="tilelang_sp",
            ref_name="ref_dense",
        )
    print("pass")


216
217
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def run_gemm_sp_sm90(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A,
    trans_B,
):
    """Build the ``matmul_sp_sm90`` kernel for one configuration and verify it.

    Note the argument order differs between this helper (dtypes first) and
    ``matmul_sp_sm90`` (block sizes first); the call below does the reordering.

    Args:
        M, N, K: GEMM problem sizes.
        in_dtype, out_dtype, accum_dtype: Input, output and accumulator types.
        block_M, block_N, block_K: Tile sizes.
        num_stages: Pipeline stage count for the K loop.
        num_threads: Thread count per kernel block.
        trans_A, trans_B: Storage-transpose flags for the operands.
    """
    kernel = matmul_sp_sm90(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
        trans_A,
        trans_B,
    )
    run_gemm_sp(
        kernel,
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        block_K,
        trans_A,
        trans_B,
    )


261
262
263
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def run_gemm_sp_sm80(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A,
    trans_B,
):
    """Build the ``matmul_sp_sm80`` kernel for one configuration and verify it.

    Note the argument order differs between this helper (dtypes first) and
    ``matmul_sp_sm80`` (block sizes first); the call below does the reordering.

    Args:
        M, N, K: GEMM problem sizes.
        in_dtype, out_dtype, accum_dtype: Input, output and accumulator types.
        block_M, block_N, block_K: Tile sizes.
        num_stages: Pipeline stage count for the K loop.
        num_threads: Thread count per kernel block.
        trans_A, trans_B: Storage-transpose flags for the operands.
    """
    kernel = matmul_sp_sm80(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
        trans_A,
        trans_B,
    )
    run_gemm_sp(
        kernel,
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        block_K,
        trans_A,
        trans_B,
    )


307
308
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        # float16 inputs with float32 accumulation, varying tile sizes,
        # pipeline depth and thread count.
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 2, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 0, 256, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 0, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 0, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False),
        # Transposed-B variant of the 64x64x64 configuration.
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
        # NOTE(review): identical to the third case above -- confirm whether a
        # different configuration was intended here.
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
        # 8-bit input types.
        (512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True),
        (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
    ],
)
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    """Check the ``matmul_sp_sm90`` kernel against a dense reference for each configuration."""
    run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)
328

329
330
331
332

@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        # float16 inputs, B stored as (K, N).
        (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 32, 0, 32, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
        # float16 inputs, B stored as (N, K).
        (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
        # float16 inputs with 1, 2 and 3 pipeline stages.
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 1, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 3, 128, False, False),
        # int8 inputs with int32 accumulation, B stored as (N, K).
        (512, 1024, 768, T.int8, T.int32, T.int32, 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True),
        (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 1, 128, False, True),
        (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
    ],
)
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    """Check the ``matmul_sp_sm80`` kernel against a dense reference for each configuration."""
    run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)
354
355
356
357


if __name__ == "__main__":
    # Script entry point: delegate to tilelang's testing helper
    # (presumably runs the tests in this file -- confirm).
    tilelang.testing.main()