# test_tilelang_tilelibrary_gemm_sp.py
import torch
import tilelang
import tilelang.testing

5
from tilelang.utils.sparse import compress
6
7
from tilelang.layout import make_metadata_layout

8
# NOTE(review): presumably turns off TileLang's kernel cache so every
# tilelang.compile() call in this module rebuilds the kernel -- confirm.
tilelang.disable_cache()

# Print tensors in full (no summarisation, very wide lines) so that mismatch
# dumps from failing comparisons are complete.
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
# Fixed seed so the randomly generated operands are reproducible across runs.
torch.manual_seed(42)

STR_TO_TYPE = {
13
    'float32': torch.float32,
14
15
16
17
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float8_e4m3": torch.float8_e4m3fn,
    "int8": torch.int8,
18
    "int32": torch.int32,
19
20
21
}

SPARSITY_MAP = {
22
    # 'float32': (1, 2),  # not supported for now
23
24
25
26
27
28
29
    torch.float16: (2, 4),
    torch.bfloat16: (2, 4),
    torch.float8_e4m3fn: (2, 4),
    torch.int8: (2, 4),
}


30
def matmul_sp_sm90(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    """Build the TileLang sparse-GEMM program that uses the arch "9.0" metadata layout.

    The returned prim_func takes a compressed ``A_sparse`` (half of K is
    stored), its metadata tensor ``E``, a dense ``B`` and writes ``C`` of shape
    ``(M, N)``. Each block accumulates a ``(block_M, block_N)`` tile over K in
    steps of ``block_K`` through ``T.gemm_sp``.

    Args:
        M, N, K: Problem sizes of the logical (uncompressed) GEMM.
        block_M, block_N, block_K: Tile sizes handled per block / per K step.
        in_dtype: dtype name of ``A_sparse`` and ``B``.
        out_dtype: dtype name of ``C``.
        accum_dtype: dtype name of the accumulator fragment.
        num_stages: ``num_stages`` passed to ``T.Pipelined`` for the K loop.
        threads: ``threads`` passed to ``T.Kernel``.
        trans_A: If true, ``A_sparse`` is laid out as ``(K // 2, M)``.
        trans_B: If true, ``B`` is laid out as ``(N, K)``.

    Returns:
        The ``@T.prim_func`` named ``main``.
    """
    # E stores K // E_factor uint8 entries per row of A.
    E_factor = 4 if in_dtype == "float32" else 8
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    import tilelang.language as T

    @T.prim_func
    def main(
            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
            E: T.Tensor((M, K // E_factor), 'uint8'),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        # bx indexes tiles along N, by indexes tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # NOTE(review): mma_dtype is fixed to "float16" here regardless of
            # in_dtype -- confirm this is intended for the 8-bit test cases.
            T.annotate_layout({
                E:
                    make_metadata_layout(
                        E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
                E_shared:
                    make_metadata_layout(
                        E_shared,
                        mma_dtype="float16",
                        arch="9.0",
                        backend="cutlass",
                        block_k=block_K),
            })
            # NOTE(review): presumably disables warp-group register
            # re-allocation for this kernel -- confirm against TileLang docs.
            T.disable_warp_group_reg_alloc()
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the metadata, A and B tiles for this K step.
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
            # Write the accumulated tile back to C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def matmul_sp_sm80(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    """Build the TileLang sparse-GEMM program that uses the arch "8.0" metadata layout.

    The returned prim_func takes a compressed ``A_sparse`` (half of K is
    stored), its metadata tensor ``E``, a dense ``B`` and writes ``C`` of shape
    ``(M, N)``. Each block accumulates a ``(block_M, block_N)`` tile over K in
    steps of ``block_K`` through ``T.gemm_sp``.

    Args:
        M, N, K: Problem sizes of the logical (uncompressed) GEMM.
        block_M, block_N, block_K: Tile sizes handled per block / per K step.
        in_dtype: dtype name of ``A_sparse`` and ``B``; names containing "8"
            select the 8-bit metadata configuration.
        out_dtype: dtype name of ``C``.
        accum_dtype: dtype name of the accumulator fragment.
        num_stages: ``num_stages`` passed to ``T.Pipelined`` for the K loop.
        threads: ``threads`` passed to ``T.Kernel``.
        trans_A: If true, ``A_sparse`` is laid out as ``(K // 2, M)``.
        trans_B: If true, ``B`` is laid out as ``(N, K)``.

    Returns:
        The ``@T.prim_func`` named ``main``.
    """
    # Substring check: true for names such as "int8" and "float8_e4m3".
    is_8_bit = "8" in in_dtype
    # E stores K // E_factor entries per row: int32 entries for 8-bit inputs,
    # int16 entries otherwise (see the E / E_shared declarations below).
    E_factor = 32 if is_8_bit else 16
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    import tilelang.language as T

    @T.prim_func
    def main(
            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
            E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        # bx indexes tiles along N, by indexes tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor),
                                      'int32' if is_8_bit else 'int16')
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            # NOTE(review): mma_dtype is fixed to "float16" here regardless of
            # in_dtype -- confirm this is intended for the int8 test cases.
            T.annotate_layout({
                E:
                    make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"),
                E_shared:
                    make_metadata_layout(
                        E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"),
            })
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the metadata, A and B tiles for this K step.
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            # Write the accumulated tile back to C.
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main


156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False):
    """Create a random float32 matrix with structured sparsity along K.

    ``dtype`` is only used to look up the ``(elem, group)`` pattern in
    ``SPARSITY_MAP``; the returned tensor is always float32. In every run of
    ``group`` consecutive K positions, ``elem`` distinct positions keep their
    random value and the rest are zeroed.

    Args:
        M: Size of the non-reduction dimension.
        K: Size of the reduction dimension; must be a multiple of ``group``.
        dtype: Key into ``SPARSITY_MAP`` selecting the sparsity pattern.
        device: Device on which the dense values are generated.
        trans_A: If true the result has shape ``(K, M)``, otherwise ``(M, K)``.

    Returns:
        The masked float32 tensor.

    Raises:
        ValueError: If ``K`` is not divisible by ``group``.
        KeyError: If ``dtype`` has no entry in ``SPARSITY_MAP``.
    """
    elem, group = SPARSITY_MAP[dtype]
    if K % group != 0:
        raise ValueError(
            f"Last dimension must be divisible by {group} for {elem}:{group} sparsity.")

    def _draw_kept_offsets():
        # Draw `elem` offsets in [0, group) and re-draw any duplicate until
        # all of them are distinct.
        offsets = torch.randint(0, group, (elem,), dtype=torch.int64)
        for pos in range(1, len(offsets)):
            while offsets[pos] in offsets[:pos]:
                offsets[pos] = torch.randint(0, group, (1,), dtype=torch.int64)
        return offsets

    if trans_A:
        dense = torch.randn(K * M, dtype=torch.float32, device=device).view(K, M)
    else:
        dense = torch.randn((M, K), dtype=torch.float32, device=device).view(M, K)
    keep = torch.zeros_like(dense, dtype=torch.bool)

    # Walk every M index and every K-group in the same order for both layouts;
    # only the axis that K is written to differs.
    for m_idx in range(M):
        for k_base in range(0, K, group):
            for offset in _draw_kept_offsets():
                if trans_A:
                    keep[k_base + offset, m_idx] = True
                else:
                    keep[m_idx, k_base + offset] = True

    return dense * keep


def normalize(tensor, max_range=100.0):
    """Scale ``tensor`` so its largest absolute value maps to ``max_range``.

    The peak magnitude is clamped from below at 1e-4 so an all-zero input does
    not divide by zero.

    Args:
        tensor: Input tensor.
        max_range: Target peak magnitude; must not exceed 448.0.
            NOTE(review): 448.0 is presumably the float8_e4m3 finite
            maximum -- confirm.

    Returns:
        The rescaled tensor.
    """
    assert max_range <= 448.0
    peak = torch.clamp(tensor.abs().max(), min=1e-4)
    return tensor * (max_range / peak)


def calc_diff(x, y):
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The value is 0 when ``x`` and ``y`` are identical and grows as they
    diverge.
    """
    lhs = x.double()
    rhs = y.double()
    cross = torch.sum(lhs * rhs)
    energy = torch.sum(lhs * lhs + rhs * rhs)
    return 1 - 2 * cross / energy


def run_gemm_sp(
    kernel,
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    block_K,
    trans_A,
    trans_B,
):
    """Compile a sparse-GEMM program, run it on random data and check the result.

    Args:
        kernel: Program to hand to ``tilelang.compile``.
        M, N, K: Problem sizes of the logical GEMM.
        in_dtype: dtype name (key of ``STR_TO_TYPE``) for A and B.
        out_dtype: dtype name (key of ``STR_TO_TYPE``) for the reference result.
        block_K: Forwarded to ``compress`` as ``block_k``.
        trans_A: Whether A is generated/consumed as ``(K, M)``.
        trans_B: Whether B is generated/consumed as ``(N, K)``.

    Raises:
        AssertionError: If the kernel output deviates from the dense reference.
    """
    # NOTE(review): out_idx=[-1] presumably marks the last kernel argument (C)
    # as the returned output, matching the 3-argument call below -- confirm.
    kernel = tilelang.compile(
        kernel,
        out_idx=[-1],
    )
    A = generate_sparse_tensor_float32(
        M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', trans_A=trans_A)
    if trans_B:
        B = torch.randn((N, K), device='cuda', dtype=torch.float32)
    else:
        B = torch.randn((K, N), device='cuda', dtype=torch.float32)

    # Rescale the 8-bit inputs before the cast to the narrow dtype.
    if "float8" in in_dtype or "int8" in in_dtype:
        A = normalize(A)
        B = normalize(B)

    A = A.to(STR_TO_TYPE[in_dtype])
    B = B.to(STR_TO_TYPE[in_dtype])

    # Split A into its compressed values and the metadata tensor E.
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)

    C_sp = kernel(A_sparse, E, B)

    def _matmul(A, B):
        # Dense reference computed on the (already sparsified) A and B.
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        # The 8-bit operands are upcast to float32 for the reference matmul.
        if "float8" in in_dtype or "int8" in in_dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype])

    C = _matmul(A, B)
    if 'float8' in in_dtype:
        # float8 results are compared with the global calc_diff metric
        # instead of element-wise tolerances.
        diff = calc_diff(C_sp, C)
        assert diff < 1e-3, f"{diff=}"
    else:
        torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
    print("pass")


254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def run_gemm_sp_sm90(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A=False,
    trans_B=False,
):
    """Build the ``matmul_sp_sm90`` program for this configuration and validate it.

    The program is constructed by ``matmul_sp_sm90`` (``num_threads`` becomes
    its ``threads`` argument) and then handed to ``run_gemm_sp`` for
    compilation, execution and comparison against the dense reference.
    """
    program = matmul_sp_sm90(
        M=M,
        N=N,
        K=K,
        block_M=block_M,
        block_N=block_N,
        block_K=block_K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_stages=num_stages,
        threads=num_threads,
        trans_A=trans_A,
        trans_B=trans_B,
    )
    run_gemm_sp(
        kernel=program,
        M=M,
        N=N,
        K=K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        block_K=block_K,
        trans_A=trans_A,
        trans_B=trans_B,
    )


def run_gemm_sp_sm80(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_M,
    block_N,
    block_K,
    num_stages,
    num_threads,
    trans_A=False,
    trans_B=False,
):
    """Build the ``matmul_sp_sm80`` program for this configuration and validate it.

    The program is constructed by ``matmul_sp_sm80`` (``num_threads`` becomes
    its ``threads`` argument) and then handed to ``run_gemm_sp`` for
    compilation, execution and comparison against the dense reference.
    """
    program = matmul_sp_sm80(
        M=M,
        N=N,
        K=K,
        block_M=block_M,
        block_N=block_N,
        block_K=block_K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_stages=num_stages,
        threads=num_threads,
        trans_A=trans_A,
        trans_B=trans_B,
    )
    run_gemm_sp(
        kernel=program,
        M=M,
        N=N,
        K=K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        block_K=block_K,
        trans_A=trans_A,
        trans_B=trans_B,
    )


340
341
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_sp_sm90():
    """Exercise the arch-9.0 sparse GEMM over several tilings, stage counts and layouts.

    Positional arguments after the dtypes are
    ``block_M, block_N, block_K, num_stages, num_threads[, trans_A, trans_B]``.
    """
    # Small block_K, with and without pipelining / with more threads.
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256)

    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)

    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128)

    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)

    # Transposed operand layouts.
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
                     True)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
                     False)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
                     True)

    # 8-bit input dtypes.
    run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
                     True)
    run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
365

366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392

@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_gemm_sp_sm80():
    """Exercise the arch-8.0 sparse GEMM over several tilings, stage counts and dtypes."""
    problem = (512, 1024, 768)  # M, N, K
    fp16 = ("float16", "float32", "float32")  # in, out, accum dtypes
    int8 = ("int8", "int32", "int32")
    b_transposed = (False, True)  # trans_A, trans_B

    # Each entry is: dtypes + (block_M, block_N, block_K, num_stages, num_threads)
    # optionally followed by (trans_A, trans_B). Order matches execution order.
    configs = [
        fp16 + (32, 32, 32, 0, 32),
        fp16 + (64, 64, 64, 0, 32),
        fp16 + (64, 64, 64, 0, 128),
        fp16 + (32, 32, 64, 0, 32) + b_transposed,
        fp16 + (64, 64, 64, 0, 32) + b_transposed,
        fp16 + (64, 64, 64, 0, 128) + b_transposed,
        fp16 + (64, 64, 64, 1, 128),
        fp16 + (64, 64, 64, 2, 128),
        fp16 + (64, 64, 64, 3, 128),
        int8 + (32, 32, 64, 0, 32) + b_transposed,
        int8 + (64, 64, 64, 0, 32) + b_transposed,
        int8 + (128, 128, 128, 0, 128) + b_transposed,
        int8 + (64, 64, 64, 1, 128) + b_transposed,
        int8 + (64, 64, 64, 2, 128) + b_transposed,
        int8 + (64, 64, 64, 3, 128) + b_transposed,
    ]
    for config in configs:
        run_gemm_sp_sm80(*problem, *config)
393
394
395
396


if __name__ == "__main__":
    # When executed as a script, hand control to TileLang's test entry point.
    tilelang.testing.main()