# example_dequant_gemm_w4a8.py
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse


def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    """Extract the `pos`-th 4-bit field of a packed uint8 and sign-extend it to int8."""
    assert nbit == 4
    assert dtype == "int8"
    assert val.dtype == "uint8"

    mask = tir.const((1 << nbit) - 1, "uint8")

    # Shift the requested nibble into the low 4 bits and mask it out.
    i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask

    # Sign-extend: place the nibble in the high 4 bits, reinterpret as int8,
    # then arithmetic-shift right so the sign bit propagates.
    i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
    i8 = i8_shifted >> tir.const(4, "int8")
    return i8
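
# Worked example of the sign-extension trick above (illustrative values only):
# a packed byte val = 0x97 stores nibble 0x7 at pos 0 and nibble 0x9 at pos 1.
#   pos 0: (0x97 >> 0) & 0xF = 0x07 -> reinterpret(0x70) =  112 ->  112 >> 4 =  7
#   pos 1: (0x97 >> 4) & 0xF = 0x09 -> reinterpret(0x90) = -112 -> -112 >> 4 = -7
# i.e. the two nibbles decode to the signed int4 values 7 and -7.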


def get_configs():
    iter_params = dict(
        block_M=[64, 128],
        block_N=[64, 128],
        block_K=[128, 256],
        num_stages=[1, 2],
        threads=[128, 256, 512],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
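
# Each entry of the returned search space is a plain dict, e.g. one of the 48
# combinations:
#   {"block_M": 64, "block_N": 64, "block_K": 128, "num_stages": 1, "threads": 128}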


@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
        B: T.Tensor(B_shape, storage_dtype),
        C: T.Tensor((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)

            for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
                # Load a packed uint8 tile, move it to registers, unpack each
                # nibble to a sign-extended int8, and store the dequantized tile.
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main
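
# A minimal usage sketch for the standalone dequantization kernel above (an
# illustrative aside, not exercised by main(); shapes and block sizes are assumed):
#   convert = _convert_test(256, 256, block_N=64, block_K=128, in_dtype="int8")
#   qB = torch.randint(0, 256, (256, 256 // 2), dtype=torch.uint8, device="cuda")
#   dequantized = convert(qB)  # (256, 256) int8 tensor, returned via out_idx=[1]
#   torch.testing.assert_close(dequantized.cpu(), torch_convert(qB.cpu()))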


def torch_convert(tensor):
    """Reference (slow) unpacking: each uint8 byte becomes two sign-extended int8 values."""

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        # Extract the requested nibble, then sign-extend via shift-left/shift-right.
        i4_shifted = (val >> (pos * 4)) & mask
        i4 = (i4_shifted << 4) >> 4
        return i4.view(torch.int8)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
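

# The element-wise loop above is easy to read but slow for large tensors. The
# helper below is a vectorized sketch of the same unpacking (an addition for
# illustration only; it is not used by ref_program and assumes the same layout
# of two signed nibbles per byte, low nibble first).
def torch_convert_vectorized(qB):
    v = qB.view(torch.int8)
    low = ((v & 0xF) << 4) >> 4  # nibble at pos 0, sign-extended to int8
    high = (((v >> 4) & 0xF) << 4) >> 4  # nibble at pos 1, sign-extended to int8
    return torch.stack((low, high), dim=-1).reshape(qB.shape[0], -1)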


def ref_program(A, qB):
    # Dequantize B on the host and compute the reference result in fp32, then
    # cast and transpose so the shape matches the kernel's (N, M) output Ct.
    dtypeC = "int32"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(getattr(torch, dtypeC))
    return C.transpose(0, 1)


def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
    @tilelang.jit(out_idx=[2])
    def kernel_func(block_M, block_N, block_K, num_stages, threads):
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_local_shape = (block_N, block_K)

        assert K % (block_K) == 0

        @T.prim_func
        def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, storage_dtype),
            Ct: T.Tensor((N, M), out_dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout(
                    {
                        B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                        Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                    }
                )
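                # The swizzled shared-memory layouts requested above are a common
                # way to reduce bank conflicts on the B and Ct tiles (a general
                # motivation assumed here, not stated by the example itself).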

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    # Stage the int8 activations and the packed int4 weights in
                    # shared memory, unpack the weights to int8 in registers,
                    # then accumulate Ct_local += B_dequant @ A^T.
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])

        return main

    if tune:

        @autotune(configs=get_configs(), warmup=10, rep=10)
        @tilelang.jit(out_idx=[2])
        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func

        return kernel()

    else:

        def kernel(block_M, block_N, block_K, num_stages, threads):
            return kernel_func(block_M, block_N, block_K, num_stages, threads)

        return kernel


def main(m=128, n=256, k=256, tune=False):
    total_flops = 2 * m * n * k
    if not tune:
        kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
            block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
        )
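
        # The compiled kernel can also be invoked directly on CUDA tensors
        # (an illustrative aside; the profiler below generates its own inputs):
        #   A = torch.randint(-128, 128, (m, k), dtype=torch.int8, device="cuda")
        #   qB = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8, device="cuda")
        #   Ct = kernel(A, qB)  # (n, m) int32 output, returned via out_idx=[2]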
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
        print("All checks pass.")

        latency = profiler.do_bench(warmup=50)
        print(f"Tilelang: {latency} ms")

    else:
        best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
        best_latency = best_result.latency
        best_config = best_result.config
        print(f"Bset latency: {best_latency}")
        print(f"Best config: {best_config}")
        print(f"Best tflops: {total_flops / best_latency * 1e-9}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
    parser.add_argument("--tune", action="store_true", help="Enable tuning")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k
    main(M, N, K, args.tune)
    # main(M, N, K, True)