test_dequantize.py 3.12 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""

import pytest
import torch
from deepspeed.ops import op_builder
from deepspeed.accelerator import get_accelerator

# Cached quantizer op module. Starts as None and is populated on first use by
# run_quantize / run_dequantize via op_builder.QuantizerBuilder().load().
quantize_module = None


def int4x2to2xint4(int4X2tensor):
    """Unpack a tensor of packed 4-bit pairs into one value per element.

    Every input element yields two outputs: the result of shifting it right
    by 4 (high nibble) followed by the result of shifting it left by 4 and
    back right by 4 (low nibble). The pairs are laid out along a new trailing
    dimension and the result is flattened, so the output is 1-D with the high
    nibble of each element immediately preceding its low nibble.

    NOTE(review): sign extension of each nibble relies on the input being a
    signed 8-bit tensor - confirm against the quantizer's output dtype.
    """
    upper_nibble = int4X2tensor >> 4
    # Shift the low nibble to the top of the element first so the right shift
    # brings it back down with the same treatment as the high nibble.
    lower_nibble = (int4X2tensor << 4) >> 4
    interleaved = torch.stack([upper_nibble, lower_nibble], dim=-1)
    return torch.flatten(interleaved)


def run_quantize(data, num_groups, q_bits, is_symmetric_quant):
    """Quantize ``data`` with the DeepSpeed quantizer op.

    The op module is built/loaded through ``op_builder.QuantizerBuilder`` the
    first time it is needed and then cached in the module-level
    ``quantize_module``.

    Args:
        data: tensor handed to the op's ``quantize`` entry point.
        num_groups: number of quantization groups, forwarded unchanged.
        q_bits: quantization bit width, forwarded unchanged.
        is_symmetric_quant: truthy selects ``quantize_module.Symmetric``,
            otherwise ``quantize_module.Asymmetric``.

    Returns:
        Whatever ``quantize_module.quantize`` returns (the test in this file
        unpacks it as ``(quantized_data, params)``).
    """
    global quantize_module
    if quantize_module is None:
        quantize_module = op_builder.QuantizerBuilder().load()

    if is_symmetric_quant:
        quant_type = quantize_module.Symmetric
    else:
        quant_type = quantize_module.Asymmetric

    return quantize_module.quantize(data, num_groups, q_bits, quant_type)


def run_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant):
    """Dequantize ``quantized_data`` with the DeepSpeed quantizer op.

    The op module is built/loaded through ``op_builder.QuantizerBuilder`` the
    first time it is needed and then cached in the module-level
    ``quantize_module``.

    Args:
        quantized_data: quantized tensor handed to the op's ``dequantize``.
        params: quantization parameters, forwarded unchanged.
        num_groups: number of quantization groups, forwarded unchanged.
        q_bits: quantization bit width, forwarded unchanged.
        is_symmetric_quant: truthy selects ``quantize_module.Symmetric``,
            otherwise ``quantize_module.Asymmetric``.

    Returns:
        Whatever ``quantize_module.dequantize`` returns.
    """
    global quantize_module
    if quantize_module is None:
        quantize_module = op_builder.QuantizerBuilder().load()

    if is_symmetric_quant:
        quant_type = quantize_module.Symmetric
    else:
        quant_type = quantize_module.Asymmetric

    return quantize_module.dequantize(quantized_data, params, num_groups, q_bits, quant_type)


def run_ref_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant):
    """Reference dequantization written in plain PyTorch.

    Args:
        quantized_data: quantized tensor; when ``q_bits == 4`` it is first
            unpacked with ``int4x2to2xint4``.
        params: per-group parameters. In the symmetric case it is multiplied
            directly against the ``(num_groups, -1)`` data; in the asymmetric
            case column 0 is used as the scale and column 1 as the offset.
        num_groups: number of rows the data is reshaped into.
        q_bits: quantization bit width; only the value 4 triggers unpacking.
        is_symmetric_quant: truthy for scale-only dequantization, otherwise
            scale-and-offset.

    Returns:
        The dequantized values as a ``torch.float16`` tensor.
    """
    unpacked = int4x2to2xint4(quantized_data) if q_bits == 4 else quantized_data

    # Do the arithmetic in fp32 and only round to fp16 at the very end.
    grouped = unpacked.reshape(num_groups, -1).to(torch.float32)

    if is_symmetric_quant:
        dequantized = grouped * params
    else:
        # Column vectors so they broadcast across every element of a group.
        scale_col = params[:, 0].reshape(-1, 1)
        offset_col = params[:, 1].reshape(-1, 1)
        dequantized = grouped * scale_col + offset_col

    return dequantized.to(torch.float16)


@pytest.mark.inference_ops
@pytest.mark.parametrize("num_groups", [1, 13, 512])
@pytest.mark.parametrize("num_elems", [8, 16, 32, 64, 128, 256, 4096, 8192, 12288, 16384])
@pytest.mark.parametrize("is_symmetric_quant", [True, False])
@pytest.mark.parametrize("q_bits", [4, 8])
def test_dequantize(num_elems, num_groups, is_symmetric_quant, q_bits):
    """Check the quantizer op's dequantize against the PyTorch reference.

    Random fp16 activations of shape ``(num_groups, num_elems)`` are created
    on the accelerator device and quantized with the op; the op's dequantize
    output and the reference output for that same quantized data must then
    agree within ``rtol=3e-2`` / ``atol=2e-3``.
    """
    device = get_accelerator().device_name()
    activations = torch.randn((num_groups, num_elems), dtype=torch.float16, device=device)
    quantized_data, params = run_quantize(activations, num_groups, q_bits, is_symmetric_quant)

    # Both implementations receive exactly the same arguments.
    dequant_args = (quantized_data, params, num_groups, q_bits, is_symmetric_quant)
    kernel_out = run_dequantize(*dequant_args)
    reference_out = run_ref_dequantize(*dequant_args)

    assert torch.allclose(kernel_out.flatten(), reference_out.flatten(), rtol=3e-2, atol=2e-3)