# Test for the infiniop Dequantize operator (AWQ-style packed-int weight dequantization).
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules.
_TEST_CASES = [
    # qweight_shape, qzeros_shape, qscales_shape, out_shape, qweight_strides, qzeros_strides,
    # qscales_strides, out_strides, qweights_dtype, qzeros_dtype, qscales_dtype, out_dtype, bits, group_size
    #
    # Shape invariants (4-bit packing, 8 values per int32 word):
    #   out_shape = (qweight_rows, qweight_cols * 8)
    #   qzeros/qscales have qweight_rows // group_size rows.
    # `None` strides mean contiguous layout.
    (
        (512, 256),
        (16, 256),
        (16, 2048),
        (512, 2048),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        32,
    ),
    (
        (1024, 128),
        (2, 128),
        (2, 1024),
        (1024, 1024),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        512,
    ),
    (
        (2048, 1024),
        (16, 1024),
        (16, 8192),
        (2048, 8192),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        128,
    ),
    (
        (4096, 512),
        (4, 512),
        (4, 4096),
        (4096, 4096),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        1024,
    ),
    (
        (8192, 256),
        (64, 256),
        (64, 2048),
        (8192, 2048),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        128,
    ),
    (
        (8192, 512),
        (32, 512),
        (32, 4096),
        (8192, 4096),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        256,
    ),
]

# Data types used for testing (output/scale dtype; packed inputs are always I32)
_TENSOR_DTYPES = [InfiniDtype.F16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 1e-4},
}

# Runtime options; overridden from command-line args in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

# AWQ packs eight 4-bit values per int32 word in interleaved order
# (even positions first, then odd): AWQ_ORDER is the packing permutation,
# AWQ_REVERSE_ORDER (its argsort) undoes it during unpacking.
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def dequantize(
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    qscales: torch.Tensor,
    bits: int,
    group_size: int,
) -> torch.Tensor:
    """PyTorch reference implementation of AWQ-style dequantization.

    Args:
        qweight: int32 tensor of shape (rows, cols); each element packs
            ``32 // bits`` quantized values.
        qzeros: int32 tensor of packed zero points, one row per group of
            ``group_size`` weight rows, or ``None`` to skip zero-point
            subtraction.
        qscales: per-group scales of shape (rows // group_size, cols * 32 // bits).
        bits: bit width of each quantized value (the AWQ reverse-order table
            used below is specific to ``bits == 4``).
        group_size: number of weight rows sharing one scale/zero-point row.

    Returns:
        Dequantized weights of shape (rows, cols * 32 // bits), computed as
        ``(quantized - zero_point) * scale`` in ``qscales``' dtype.
    """
    # Bit offsets of the packed values inside each 32-bit word.
    shifts = torch.arange(0, 32, bits, device=qweight.device)

    # Unpack qweight column-wise: each int32 expands into 32 // bits values.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    iweights = iweights.view(iweights.shape[0], -1)

    # Unpack qzeros column-wise the same way, when zero points are provided.
    if qzeros is not None:
        izeros = torch.bitwise_right_shift(
            qzeros[:, :, None], shifts[None, None, :]
        ).to(
            torch.int8  # smallest dtype available
        )
        izeros = izeros.view(izeros.shape[0], -1)
    else:
        izeros = None

    # Reverse the AWQ-specific packing order: values are interleaved within
    # each 32-bit word. This permutation equals the module-level
    # AWQ_REVERSE_ORDER (argsort of [0, 2, 4, 6, 1, 3, 5, 7]; 4-bit packing).
    awq_reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]
    # NOTE: the index tensor must live on qweight's device; deriving the
    # device from izeros would crash when qzeros is None.
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        device=iweights.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, awq_reverse_order]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    if izeros is not None:
        izeros = izeros[:, reverse_order_tensor]
    iweights = iweights[:, reverse_order_tensor]

    # Mask off the higher bits to recover the actual quantized values.
    iweight = torch.bitwise_and(iweights, (2**bits) - 1)

    # Expand scales (and zero points) to the full weight dimensions, then
    # apply: dequantized = (quantized - zero_point) * scale.
    qscales = qscales.repeat_interleave(group_size, dim=0)
    if izeros is not None:
        izeros = torch.bitwise_and(izeros, (2**bits) - 1)
        izeros = izeros.repeat_interleave(group_size, dim=0)
        iweight = (iweight - izeros) * qscales
    else:
        # No zero points supplied: treat the zero point as 0.
        iweight = iweight * qscales

    return iweight
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
    handle,
    device,
    qweights_shape,
    qzeros_shape,
    qscales_shape,
    out_shape,
    qweights_stride,
    qzeros_stride,
    qscales_stride,
    out_stride,
    qweights_dtype,
    qzeros_dtype,
    qscales_dtype,
    out_dtype,
    bits,
    group_size,
    dtype=None,
    sync=None,
):
    """Run one Dequantize test case against the PyTorch reference.

    Builds random packed inputs, computes the reference result with
    ``dequantize``, runs the infiniop operator through its descriptor /
    workspace lifecycle, and asserts the outputs match within the tolerance
    configured in ``_TOLERANCE_MAP``. Optionally profiles both paths when
    the module-level ``PROFILE`` flag is set.

    Raises:
        AssertionError: if the operator output disagrees with the reference.
    """
    print(
        f"Testing Dequantize on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size},"
        f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape},"
        f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride},"
        f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}"
    )

    # Packed integer inputs use random integer data; scales use the default
    # (floating-point) initialization.
    qweights = TestTensor(
        qweights_shape, qweights_stride, qweights_dtype, device, mode="randint"
    )
    qzeros = TestTensor(qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint")
    qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device)
    out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros")

    # Compute the PyTorch reference result
    def torch_dequantize():
        return dequantize(
            qweights.torch_tensor(),
            qzeros.torch_tensor(),
            qscales.torch_tensor(),
            bits,
            group_size,
        )

    ans = torch_dequantize()

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateDequantizeDescriptor(
            handle,
            ctypes.byref(descriptor),
            out.descriptor,
            qweights.descriptor,
            qscales.descriptor,
            qzeros.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [qweights, qzeros, qscales, out]:
        tensor.destroy_desc()

    # Get workspace size and create workspace
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetDequantizeWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    # Execute the infiniop dequantize operator
    def lib_dequantize():
        check_error(
            LIBINFINIOP.infiniopDequantize(
                descriptor,
                workspace.data(),
                workspace_size.value,
                out.data(),
                qweights.data(),
                qscales.data(),
                qzeros.data(),
                None,
            )
        )

    lib_dequantize()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)

    if DEBUG:
        debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)

    assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()

    # Apply command-line overrides to the module-level test settings.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Run every configured case on each requested device.
    for dev in get_test_devices(args):
        test_operator(dev, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")