"driver/device_direct_convolution_3.cuh" did not exist on "a5bcde36e3a53e6ee68ee48af96c7441f620f574"
example_dequant_gemm_fp4_hopper.py 12.1 KB
Newer Older
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse


def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
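    """Decode the `pos`-th 4-bit value (s1e2m1) packed in the uint8 `val` and
    rebuild the corresponding float16 bit pattern with integer ops plus a
    reinterpret."""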
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint8"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + (bias(f16) - bias(f4)) = e_f4 + (15 - 1) = e_f4 + 14
    # fp4 layout: s1e2m1 (1 sign bit, 2 exponent bits, 1 mantissa bit)
    mask = tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
    e_f16 = e_f4 + tir.const(14, "uint16")
    m_f4 = f4 & tir.const(1, "uint16")
    m_f16 = m_f4
    val_f16 = tir.reinterpret("float16",
                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
                               | m_f16 << tir.const(9, "uint16")).astype("uint16"))
    # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
    return val_f16


def torch_convert(tensor):
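    """Pure-PyTorch reference: unpack an (N, K) uint8 tensor holding two fp4
    values per byte into an (N, 2*K) float16 tensor."""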

    def print_bit(name, val):
        val_cpu = val.cpu().item()
        binary_repr = f'{val_cpu:032b}'
        print(name, binary_repr)

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
        s = f4 >> 3
        e_f4 = (f4 & 6) >> 1
        e_f16 = e_f4 + 14
        m_f4 = f4 & 1
        m_f16 = m_f4
        val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
        lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
        return lower_16_bits.view(torch.float16)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor


@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
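    """Dequantization-only kernel: stream packed uint8 tiles of B through shared
    memory, decode them to float16 with `_tir_u8_to_f4_to_f16`, and write the
    result to C. `out_idx=[1]` tells `tilelang.jit` that C is the output tensor,
    so the compiled kernel is invoked as `C = kernel(B)`."""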
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Tensor(B_shape, storage_dtype),
            C: T.Tensor((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main


def test_fp4_fp16_convert_close():
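    """Check the tilelang dequantization kernel against the PyTorch reference."""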
    N, K = 256, 256
    block_N, block_K = 64, 64
    kernel = test_convert(
        N,
        K,
        block_N,
        block_K,
        "float16",
    )

    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
    tl_out = kernel(B)
    ref_out = torch_convert(B)
    assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
    print("Pass")


def get_configs():
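    """Enumerate the autotuning search space as a list of config dicts."""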
    block_M = [128]
    block_N = [128, 256]
    block_K = [128]
    num_stages = [2]
    threads = [256]
    splits = [1]
    _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'block_K': c[2],
        'num_stages': c[3],
        'threads': c[4],
        'split': c[5]
    } for c in _configs]
    return configs


def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
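    """Build the FP4-dequantize GEMM kernel.

    With `tune=False`, return a factory that takes an explicit config
    (block sizes, stages, threads, split). With `tune=True`, run the autotuner
    over `get_configs()` and return the best result.
    """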

    @tilelang.jit(out_idx=[2])
    def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
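        """Return `main` for split == 1, or `main_split`, which partitions the
        reduction dimension across `split` blocks and sums the partial results
        (stored in SplitC) in a second kernel."""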
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_shared_shape = (block_N, block_K)
        assert K % (block_K * split) == 0
        KK = K // split

        @T.prim_func
        def main_split(
                A: T.Tensor(A_shape, in_dtype),
                B: T.Tensor(B_shape, storage_dtype),
                Ct: T.Tensor((N, M), out_dtype),
        ):
            SplitC = T.alloc_buffer([
                split, (N + block_N - 1) // block_N * block_N,
                (M + block_M - 1) // block_M * block_M
            ], out_dtype)
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
                    threads=threads) as (bx, by, bz):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })
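
                # Split-K: block bz accumulates a partial GEMM over its
                # KK = K // split slice of the reduction dimension into SplitC[bz].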
                T.clear(Ct_local)
                for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
                    T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
                    T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
                                        by * block_M:(by + 1) * block_M])
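
            # Second kernel: sum the `split` partial results into the final (N, M) output Ct.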
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
                acc = T.alloc_fragment((block_N, block_M), out_dtype)
                T.clear(acc)
                for k in range(split):
                    for i, j in T.Parallel(block_N, block_M):
                        acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
                T.copy(acc, Ct[bx * block_N, by * block_M])

        @T.prim_func
        def main(
                A: T.Tensor(A_shape, in_dtype),
                B: T.Tensor(B_shape, storage_dtype),
                Ct: T.Tensor((N, M), out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })
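
                # Pipelined main loop: stage A and packed B tiles into shared memory,
                # dequantize FP4 -> FP16 into register fragments, then accumulate
                # Ct_local += B_dequantized @ A^T (Ct holds the transposed result).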
                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        if split == 1:
            return main
        else:
            return main_split

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
            warmup=10,
            rep=10)
        @tilelang.jit(out_idx=[2])
        def kernel(block_M=None,
                   block_N=None,
                   block_K=None,
                   num_stages=None,
                   threads=None,
                   split=None):
            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)

        return kernel()
    else:

        def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)

        return kernel


def ref_program(A, qB):
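    """Reference result: dequantize qB with `torch_convert`, compute A @ B^T in
    float32, and return it transposed to (N, M) as float16 to match the kernel."""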
    dtypeC = "float16"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(getattr(torch, dtypeC))
    return C.transpose(0, 1)


def main(m=256, n=256, k=256, tune=False):
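    """Validate the kernel against `ref_program` and report latency and TFLOPS,
    either for a fixed config or for the best autotuned config."""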
    total_flops = 2 * m * n * k

    if (not tune):
        kernel = matmul(
            m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
                block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
        best_latency = best_result.latency
        best_config = best_result.config
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--m', type=int, default=256, help='M')
    parser.add_argument('--n', type=int, default=256, help='N')
    parser.add_argument('--k', type=int, default=256, help='K')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k
    main(M, N, K, args.tune)