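"""Dequantized GEMV example: FP16 activations times INT4 weights packed in INT8 storage.

Each int8 storage element holds two 4-bit weights. The kernel dequantizes B on
the fly (optionally through a fast LOP3-based extern decoder) and reduces over
K with a cross-thread allreduce.
"""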
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
    _tir_packed_int_to_int_convert,
)


@tilelang.jit
def dequantize_gemv(
    M: int,
    N: int,
    K: int,
    in_dtype: str,
    out_dtype: str,
    accum_dtype: str,
    num_bits: int = 4,
    storage_dtype: str = "int8",
    source_format: str = "uint",
    n_partition: int = 4,
    reduce_thread: int = 32,
    fast_decoding: bool = False,
    trans_A: bool = False,
    trans_B: bool = True,
    group_size: int = -1,
    with_scaling: bool = False,
) -> Callable[..., Any]:
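    """Build a dequantize-GEMV kernel computing C[M, N] = A[M, K] @ dequant(B).T.

    B holds `num_bits`-bit weights packed into `storage_dtype` elements. Each
    block computes `n_partition` output columns of one row of C, with
    `reduce_thread` threads cooperating on the K reduction. With
    `fast_decoding`, a LOP3-based extern decoder (which expects an interleaved
    weight layout) replaces the scalar unpacking loop.
    """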
    assert n_partition is not None, "n_partition must be provided"
    assert reduce_thread is not None, (
        "reduce_thread must be provided currently, as the related "
        "bitblas.gpu.gemv.GEMV.sch_outer_reduction_with_config is not implemented"
    )
    assert trans_A is False, "Dequantization is only implemented for trans_A=False currently"
    assert trans_B is True, "Dequantization is only implemented for trans_B=True currently"
    storage_type = "".join(c for c in storage_dtype if not c.isdigit())
    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
    num_elems_per_byte = storage_nbit // num_bits

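    # Each vectorized access moves 128 bits. For the float16/int4 defaults this
    # is 128 / 16 = 8 input elements (micro_size_k) but only 8 // 2 = 4 packed
    # int8 bytes (micro_size_k_compressed); each outer K iteration covers
    # block_K = reduce_thread * micro_size_k elements.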
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
    micro_size_k_compressed = micro_size_k // num_elems_per_byte
    block_K = reduce_thread * micro_size_k

    if group_size == -1:
        group_size = K

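    # B is stored packed: K weights of num_bits each occupy
    # K // storage_nbit * num_bits == K // num_elems_per_byte storage elements
    # per row.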
    A_shape = (M, K)
    B_shape = (N, K // storage_nbit * num_bits)
    C_shape = (M, N)

    dp4a_size = 4
    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"

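    # With fast_decoding, resolve a LOP3-based intrinsic that decodes a whole
    # packed fragment in a single extern call instead of a scalar unpack loop.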
    import_source: Optional[str] = None
    func_name: str = ""
    if fast_decoding:
        # Lazy import to reduce startup time, since loading the intrin
        # registry can take a while.
        from tilelang.quantize import get_lop3_intrin_group

        lop3_intrin_info = get_lop3_intrin_group(
            out_dtype=in_dtype,
            source_format=source_format,
            source_bit=num_bits,
            storage_dtype=storage_dtype,
            with_scaling=with_scaling,
            with_zeros=False,
        )
        import_source = lop3_intrin_info["c_source"]
        func_name = lop3_intrin_info["func_name"]
        assert import_source is not None, "lop3 intrinsic source not found"
        assert func_name is not None, "lop3 intrinsic function name not found"

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, storage_dtype),
        C: T.Tensor(C_shape, out_dtype),
    ):
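        # Grid: one block per (n_partition output columns, one row of A), i.e.
        # (ceildiv(N, n_partition), M). Within a block, threadIdx.y selects the
        # output column and the threadIdx.x lanes split the K reduction.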
        with T.Kernel(
            T.ceildiv(N, n_partition),
            M,
            threads=(reduce_thread, n_partition),
        ) as (
            bx,
            by,
        ):
            A_local = T.alloc_local((micro_size_k,), in_dtype)
            B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
            B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
            accum_res = T.alloc_local((1,), accum_dtype)
            reduced_accum_res = T.alloc_local((1,), accum_dtype)

            kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
            ni = T.thread_binding(0, n_partition, thread="threadIdx.y")

            T.import_source(import_source)

            T.clear(accum_res)
            for ko in T.serial(T.ceildiv(K, block_K)):
                for v in T.vectorized(micro_size_k):
                    A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]

                for v in T.vectorized(micro_size_k_compressed):
                    B_quant_local[v] = B[
                        bx * n_partition + ni,
                        ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
                    ]

                if fast_decoding:
                    T.call_extern(
                        func_name,
                        T.address_of(B_quant_local[0]),
                        T.address_of(B_dequantize_local[0]),
                        dtype=in_dtype,
                    )
                else:
                    for ki in T.serial(micro_size_k):
                        B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                            num_bits,
                            B_quant_local[ki // num_elems_per_byte],
                            ki % num_elems_per_byte,
                            in_dtype,
                        )

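                # dp4a fuses a 4-element int8 dot product with accumulate into
                # one instruction; it only applies on the int8/int32 path.
                # Otherwise, fall back to a scalar multiply-accumulate loop.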
                if use_dp4a:
                    for ki in T.serial(micro_size_k // dp4a_size):
                        T.dp4a(
                            A_local[ki * dp4a_size],
                            B_dequantize_local[ki * dp4a_size],
                            accum_res[0],
                        )
                else:
                    for ki in T.serial(micro_size_k):
                        accum_res[0] += A_local[ki] * B_dequantize_local[ki]

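            # Cross-thread allreduce: sum the reduce_thread partial sums held
            # in accum_res across the kr lanes into reduced_accum_res.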
            with T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        accum_res[0],
                        True,
                        reduced_accum_res[0],
                        kr,
                        dtype="handle",
                    )
                )
            if kr == 0:
                C[by, bx * n_partition + ni] = reduced_accum_res[0]

    return main


def main() -> None:
    M = 1
    N = 1024
    K = 1024
    in_dtype = "float16"
    out_dtype = "float16"
    accum_dtype = "float16"
    num_bits = 4
    storage_dtype = "int8"
    source_format = "uint"
    n_partition = 4
    reduce_thread = 32
    fast_decoding = True
    trans_A = False
    trans_B = True
    group_size = -1
    with_scaling = False

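    # With these defaults: micro_size_k = 128 / 16 = 8, block_K = 32 * 8 = 256,
    # so K = 1024 takes 4 outer iterations; the grid is
    # (1024 / 4, 1) = (256, 1) blocks of 32 x 4 threads.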
    kernel = dequantize_gemv(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_bits,
        storage_dtype,
        source_format,
        n_partition,
        reduce_thread,
        fast_decoding,
        trans_A,
        trans_B,
        group_size,
        with_scaling,
    )

    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
    num_elems_per_byte = storage_nbit // num_bits
    A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
    qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
    C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()

    qB_device = qB
    if fast_decoding:
        from tilelang.quantize.utils import interleave_weight

        # The LOP3 decoder expects an interleaved weight layout; keep the
        # original qB for the plain reference computation below.
        qB_device = interleave_weight(qB, num_bits, in_dtype)
    kernel(A, qB_device, C)
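
    # Optionally inspect the generated CUDA source; tilelang JIT kernels
    # expose get_kernel_source() for this (uncomment to print):
    # print(kernel.get_kernel_source())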

    # int4 reference: unpack each nibble from the packed storage and compare
    # against a plain fp16 matmul.
    B = torch.zeros(qB.shape[0], qB.shape[1] * num_elems_per_byte, dtype=torch.half, device=A.device)
    for j in range(B.shape[1]):
        shift = (j % num_elems_per_byte) * num_bits
        B[:, j] = ((qB[:, j // num_elems_per_byte] >> shift) & ((1 << num_bits) - 1)).to(torch.half)

    # Get Reference Result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    print("C: ", C)
    print("Ref C: ", ref_c)
    # Scaling is not applied, so the absolute error can be large; use loose tolerances.
    torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)


if __name__ == "__main__":
    main()