# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tilelang
from tilelang import Profiler
import tilelang.language as T
from tilelang.autotuner import *
from tilelang import tvm
from tvm import tir
import itertools
import torch
import argparse
from functools import partial

def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    """Build a TIR expression that decodes the ``pos``-th fp4 value packed in ``val``.

    The fp4 layout is s1e2n1 (1 sign bit, 3 exponent-like bits); two such values
    are packed per uint8.  All bit manipulation is done in uint16 so the result
    can be reinterpreted directly as an IEEE half-precision pattern.
    """
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint8"

    def u16(x):
        # Shorthand: a uint16 immediate constant.
        return tir.const(x, "uint16")

    # Select the requested nibble: shift by pos * 4 and mask the low 4 bits.
    nibble = (val >> (pos.astype("uint16") * u16(nbit))) & u16((1 << nbit) - 1)
    sign = nibble >> u16(3)
    # e_f4 != 0 -> e_f16 = e_f4 + 8, i.e. OR in the 0b1000 bias bit.
    exponent = (nibble & u16(7)) | u16(8)
    half_bits = ((exponent | (sign << u16(5))) << u16(10)).astype("uint16")
    # NOTE(review): the e_f4 == 0 -> exact zero Select was deliberately dropped,
    # so an all-zero nibble decodes to +2**-7 rather than 0.
    return tir.reinterpret("float16", half_bits)

def torch_convert(tensor):
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    def print_bit(name, val):
        val_cpu = val.cpu().item()
        binary_repr = f'{val_cpu:032b}'
        print(name, binary_repr)

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
        s = f4 >> 3
        e_f4 = f4 & 7
        e_f16 = e_f4 | 8
        val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
        lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
        return lower_16_bits.view(torch.float16)
53

54
55
56
57
58
59
60
61
    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor

def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    """Build a standalone TileLang kernel that dequantizes packed fp4 B into C.

    B is (N, K // 2) uint8 (two fp4 values per byte); C is the (N, K) ``in_dtype``
    dequantized result.  Used by ``test_fp4_fp16_convert_close`` to validate the
    decode path in isolation (no GEMM).
    """
    # 8 // 4 = 2 fp4 elements per storage byte.
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Buffer(B_shape, storage_dtype),
            C: T.Buffer((N, K), in_dtype),
    ):
        # One CTA per block_N rows; the K axis is walked by the pipelined loop.
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                # Stage packed bytes: global -> shared -> register fragment.
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                # Decode two fp4 values out of each byte.
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                # Write the dequantized tile back to global memory.
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main

def test_fp4_fp16_convert_close():
    """Check the TileLang fp4->fp16 dequant kernel against the torch reference.

    Builds the standalone conversion kernel from ``test_convert``, runs it on a
    random packed-fp4 matrix on CUDA, and compares elementwise against
    ``torch_convert``.
    """
    N, K = 256, 256
    block_N, block_K = 64, 64
    program = test_convert(
        N,
        K,
        block_N,
        block_K,
        "float16",
    )

    mod, params = tilelang.lower(program)
    # Buffer index [1] marks C as the output; integer supply keeps the
    # comparison exact.
    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)

    # randint(0, 16) already yields uint8 with the requested dtype, so no extra
    # cast is needed (the original redundant .to(torch.uint8) was dropped).
    # Only low-nibble patterns (high nibble zero) are exercised here.
    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
    tl_out = mod.func(B)
    ref_out = torch_convert(B)
    assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
    print("Pass")

def get_configs():
    """Enumerate autotuning configurations as a list of keyword dicts.

    The Cartesian product of the per-parameter candidate lists is taken in
    declaration order, giving one dict per combination.
    """
    keys = ('block_M', 'block_N', 'block_K', 'num_stages', 'threads', 'split')
    choices = (
        [128],        # block_M
        [128, 256],   # block_N
        [128],        # block_K
        [2],          # num_stages
        [256],        # threads
        [1],          # split
    )
    return [dict(zip(keys, combo)) for combo in itertools.product(*choices)]

def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
    """Build (or autotune) a dequant-GEMM computing Ct = (A @ dequant(B).T).T.

    A is (M, K) ``in_dtype``; B is fp4 packed two-per-byte into an
    (N, K // (8 // num_bits)) uint8 buffer; the output Ct is (N, M)
    ``out_dtype`` (note the transposed layout).

    When ``tune`` is False, returns a factory
    ``kernel(block_M, block_N, block_K, num_stages, threads, split=1)`` that
    produces a ``T.prim_func``; when True, runs the autotuner over
    ``get_configs()`` and returns its result.
    """

    def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
        # 8 // 4 = 2 fp4 elements per storage byte.
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_shared_shape = (block_N, block_K)
        assert K % (block_K * split) == 0
        # Width of the K-slice each split handles (split-K variant only).
        KK = K // split

        @T.prim_func
        def main_split(
                A: T.Buffer(A_shape, in_dtype),
                B: T.Buffer(B_shape, storage_dtype),
                Ct: T.Buffer((N, M), out_dtype),
        ):
            # Split-K variant: each split bz accumulates a partial GEMM over
            # its own K-slice into SplitC (padded up to whole blocks); a second
            # kernel scope below reduces the partials into Ct.
            SplitC = T.alloc_buffer([
                split, (N + block_N - 1) // block_N * block_N,
                (M + block_M - 1) // block_M * block_M
            ], out_dtype)
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
                    threads=threads) as (bx, by, bz):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                # Swizzled shared-memory layouts — presumably to avoid bank
                # conflicts; confirm against tilelang layout docs.
                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
                    # Offset KK * bz selects this split's K-slice.
                    T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
                    T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    # Decode two fp4 values out of each packed byte.
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    # B fragment is the left operand with transpose_B, so the
                    # accumulator tile is (block_N, block_M).
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
                                        by * block_M:(by + 1) * block_M])
            # Reduction: sum the per-split partial tiles into the final output.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
                acc = T.alloc_fragment((block_N, block_M), out_dtype)
                T.clear(acc)
                for k in range(split):
                    for i, j in T.Parallel(block_N, block_M):
                        acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
                T.copy(acc, Ct[bx * block_N, by * block_M])

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, in_dtype),
                B: T.Buffer(B_shape, storage_dtype),
                Ct: T.Buffer((N, M), out_dtype),
        ):
            # Plain (non split-K) variant: one CTA per (block_N, block_M) tile,
            # walking the whole K axis in the pipelined loop.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                # Same swizzled layouts as the split-K variant.
                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    # Decode two fp4 values out of each packed byte.
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                # Stage the accumulator through shared memory before the
                # strided write to global.
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        if split == 1:
            return main
        else:
            return main_split

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[2],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None,
                   block_N=None,
                   block_K=None,
                   num_stages=None,
                   threads=None,
                   split=None):
            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)

        # kernel() triggers the autotune sweep immediately.
        return kernel()
    else:

        def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)

        return kernel

def ref_program(A, qB):
    """Reference GEMM matching the kernel's transposed output layout.

    Dequantizes the packed-fp4 ``qB`` to fp16 via ``torch_convert``, computes
    A @ B.T in float32 (mirroring the kernel's fp32 accumulation), then casts
    down to fp16 and transposes so the result is (N, M) like Ct.
    """
    dtypeC = "float16"
    B = torch_convert(qB)
    # Accumulate in fp32 to match the kernel's accum_dtype, then cast down.
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    # getattr is the idiomatic way to resolve a dtype by name
    # (was torch.__getattribute__(dtypeC)).
    C = C.to(getattr(torch, dtypeC))
    return C.transpose(0, 1)

if __name__ == "__main__":
    # CLI entry point: validate and benchmark the fp4 dequant GEMM.
    parser = argparse.ArgumentParser()
    parser.add_argument('--m', type=int, default=256, help='M')
    parser.add_argument('--n', type=int, default=256, help='N')
    parser.add_argument('--k', type=int, default=256, help='K')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k
    total_flops = 2 * M * N * K

    if args.tune:
        # Autotune path: matmul drives the sweep and returns the best result.
        best_latency, best_config, ref_latency = matmul(
            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
    else:
        # Fixed-config path: build one kernel, check it, then benchmark both
        # the torch reference and the TileLang kernel.
        builder = matmul(
            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
        program = builder(
            block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
        lowered, params = tilelang.lower(program)
        profiler = Profiler(lowered, params, [2], tilelang.TensorSupplyType.Integer)
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        ref_ms = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(ref_ms))
        print("Ref: {:.2f} TFlops".format(total_flops / ref_ms * 1e-9))
        tl_ms = profiler.do_bench(profiler.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(tl_ms))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / tl_ms * 1e-9))