# dequantize_awq.py — test for the InfiniOp Dequantize AWQ operator.
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # qweight_shape, qzeros_shape, qscales_shape, out_shape, qweight_strides, qzeros_strides,
    # qscales_strides, out_strides, qweights_dtype, qzeros_dtype, qscales_dtype, out_dtype, bits, group_size
    #
    # Shape relations visible in the data (bits=4, so 8 values per int32 word):
    #   qzeros rows  = qweight rows / group_size
    #   qscales rows = qzeros rows; qscales cols = out cols = qweight cols * 8
    # All strides are None (contiguous tensors).
    (
        (512, 256),
        (16, 256),
        (16, 2048),
        (512, 2048),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        32,
    ),
    (
        (1024, 128),
        (2, 128),
        (2, 1024),
        (1024, 1024),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        512,
    ),
    (
        (2048, 1024),
        (16, 1024),
        (16, 8192),
        (2048, 8192),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        128,
    ),
    (
        (4096, 512),
        (4, 512),
        (4, 4096),
        (4096, 4096),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        1024,
    ),
    (
        (8192, 256),
        (64, 256),
        (64, 2048),
        (8192, 2048),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        128,
    ),
    (
        (8192, 512),
        (32, 512),
        (32, 4096),
        (8192, 4096),
        None,
        None,
        None,
        None,
        InfiniDtype.I32,
        InfiniDtype.I32,
        InfiniDtype.F16,
        InfiniDtype.F16,
        4,
        256,
    ),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 1e-4},
}

# Runtime options; overridden from command-line arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

# AWQ interleaves the eight 4-bit values inside each 32-bit word.
# AWQ_ORDER maps logical index -> packed lane; AWQ_REVERSE_ORDER inverts it
# (used below to restore logical ordering after bit-unpacking).
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def dequantize_awq(
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    qscales: torch.Tensor,
    bits: int,
    group_size: int,
):
    """PyTorch reference dequantization of AWQ-packed weights.

    Args:
        qweight: int32 tensor where each 32-bit word packs ``32 // bits``
            quantized values along the last dimension.
        qzeros: int32 tensor of packed zero points with one row per group,
            or None (then no zero point is subtracted).
        qscales: per-group scale tensor; its rows are expanded by
            ``group_size`` to match the unpacked weight rows.
        bits: bit width of each quantized value (e.g. 4).
        group_size: number of weight rows sharing one scale/zero-point row.

    Returns:
        Dequantized weight tensor of shape
        (qweight.shape[0], qweight.shape[1] * 32 // bits), with the dtype
        produced by multiplying with ``qscales``.
    """
    # Bit offsets of the packed values inside each int32 word.
    shifts = torch.arange(0, 32, bits, device=qweight.device)

    # Unpacking qweight columnwise: one lane per packed value.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    iweights = iweights.view(iweights.shape[0], -1)

    # Unpacking qzeros columnwise (optional).
    if qzeros is not None:
        izeros = torch.bitwise_right_shift(
            qzeros[:, :, None], shifts[None, None, :]
        ).to(
            torch.int8  # smallest dtype available
        )
        izeros = izeros.view(izeros.shape[0], -1)
    else:
        izeros = None

    # Reverse the AWQ-specific packing order: within each 32-bit word the
    # values are interleaved, so permute every group of 32 // bits columns
    # back to logical order. (Same permutation as module-level
    # AWQ_REVERSE_ORDER; defined locally so this reference is self-contained.)
    awq_reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        # BUG FIX: was izeros.device, which raises AttributeError when
        # qzeros is None; iweights always exists.
        device=iweights.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, awq_reverse_order]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    if izeros is not None:
        izeros = izeros[:, reverse_order_tensor]
    iweights = iweights[:, reverse_order_tensor]

    # Extract the actual quantized values by masking off the higher bits.
    mask = (2**bits) - 1
    iweight = torch.bitwise_and(iweights, mask)

    # Expand scaling factors (and zeros) to the full weight rows, then apply
    # the dequantization formula: dequantized = (quantized - zero_point) * scale.
    qscales = qscales.repeat_interleave(group_size, dim=0)
    if izeros is not None:
        # BUG FIX: original unconditionally did bitwise_and on izeros,
        # crashing when qzeros is None.
        izeros = torch.bitwise_and(izeros, mask)
        izeros = izeros.repeat_interleave(group_size, dim=0)
        iweight = (iweight - izeros) * qscales
    else:
        iweight = iweight * qscales

    return iweight


# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
    handle,
    device,
    qweights_shape,
    qzeros_shape,
    qscales_shape,
    out_shape,
    qweights_stride,
    qzeros_stride,
    qscales_stride,
    out_stride,
    qweights_dtype,
    qzeros_dtype,
    qscales_dtype,
    out_dtype,
    bits,
    group_size,
    dtype=None,
    sync=None,
):
    """Run one Dequantize AWQ test case against the infiniop library.

    Builds random packed inputs, computes the PyTorch reference via
    dequantize_awq(), executes the library operator, and asserts both
    results agree within the tolerance configured for ``dtype``.
    Optionally profiles both implementations when PROFILE is set.
    """
    print(
        f"Testing Dequantize AWQ on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size},"
        f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape},"
        f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride},"
        f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}"
    )

    # Packed integer inputs get random integers; scales get default random data.
    qweights = TestTensor(
        qweights_shape, qweights_stride, qweights_dtype, device, mode="randint"
    )
    qzeros = TestTensor(
        qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint"
    )
    qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device)
    out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros")

    # Compute the PyTorch reference result
    def torch_dequantize_awq():
        return dequantize_awq(
            qweights.torch_tensor(),
            qzeros.torch_tensor(),
            qscales.torch_tensor(),
            bits,
            group_size,
        )

    # FIX: the original pre-allocated `ans` as a TestTensor of ones and then
    # immediately overwrote it here — that dead allocation is removed.
    ans = torch_dequantize_awq()

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateDequantizeAWQDescriptor(
            handle,
            ctypes.byref(descriptor),
            out.descriptor,
            qweights.descriptor,
            qscales.descriptor,
            qzeros.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from
    # being directly used by the kernel
    for tensor in [qweights, qzeros, qscales, out]:
        tensor.destroy_desc()

    # Get workspace size and create workspace
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetDequantizeAWQWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    # Execute the infiniop dequantize-AWQ operator
    def lib_dequantize_awq():
        check_error(
            LIBINFINIOP.infiniopDequantizeAWQ(
                descriptor,
                workspace.data(),
                workspace_size.value,
                out.data(),
                qweights.data(),
                qscales.data(),
                qzeros.data(),
                None,
            )
        )

    lib_dequantize_awq()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
    assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_dequantize_awq(), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_dequantize_awq(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyDequantizeAWQDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()

    # Apply command-line overrides to the module-level test settings.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every test case on every requested device.
    for dev in get_test_devices(args):
        test_operator(dev, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")