"vscode:/vscode.git/clone" did not exist on "4ff719ad57b6c16bbb7d104dbeb3c3c9cfa789cb"
quantizer.py 5.61 KB
Newer Older
pppppM's avatar
pppppM committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, Optional

import torch

from lmdeploy.lite.utils import (QParams, cal_qparams_per_channel_absmax,
                                 cal_qparams_per_channel_minmax,
                                 cal_qparams_per_group_absmax,
                                 cal_qparams_per_group_minmax,
                                 cal_qparams_per_tensor_absmax,
                                 cal_qparams_per_tensor_minmax)
from lmdeploy.lite.utils.global_avail import GlobalAvailMixin


class WeightQuantizer(GlobalAvailMixin):
    """A class for performing weight quantization of neural networks.

    The WeightQuantizer class provides various methods to quantize the weights
    of a neural network. This helps in reducing the memory requirements and
    computational complexity of the model, potentially offering faster
    inference and lower power consumption.

    Attributes:
        bits (int): The bit width for quantization.
        symmetry (bool): If True, use absmax scaling; if False,
            use min-max scaling.
        granularity (str): The granularity of quantization. Available options
            are 'per_channel', 'per_tensor', and 'per_group'.
        group_size (int): The number of input channels in each group when
            using 'per_group' quantization; ignored otherwise.

    Example:

        # Instantiate the weight quantizer with specific quantization settings
        quantizer = WeightQuantizer(bits=8,
                                    symmetry=True,
                                    granularity='per_tensor')

        # Calculate the quantization parameters for given weights
        qparams = quantizer.calculate_qparams(weights)

        # Perform fake quantization on the weights
        quantized_weights = quantizer.fake_quant(weights, qparams)
    """

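    # Maps granularity ('per_group', 'per_channel', 'per_tensor') and
    # observer ('absmax', 'minmax') to the matching qparams calculator.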
    CAL_FUNC_MAP: Dict[str, Dict[str, Callable]] = {
        'per_group': {
            'absmax': cal_qparams_per_group_absmax,
            'minmax': cal_qparams_per_group_minmax,
        },
        'per_channel': {
            'absmax': cal_qparams_per_channel_absmax,
            'minmax': cal_qparams_per_channel_minmax,
        },
        'per_tensor': {
            'absmax': cal_qparams_per_tensor_absmax,
            'minmax': cal_qparams_per_tensor_minmax,
        },
    }

    def __init__(self,
                 bits: int,
                 symmetry: bool,
                 granularity: str,
                 group_size: int = -1):

        assert bits in [4, 8], "The 'bits' argument must be either 4 or 8."
        self.bits = bits

        if granularity not in ['per_channel', 'per_tensor', 'per_group']:
            raise NotImplementedError(
                "The 'granularity' argument must be one of 'per_channel', "
                "'per_tensor', or 'per_group'.")

        self.granularity = granularity

        if self.granularity == 'per_group':
            assert group_size > 0, \
                "The 'group_size' argument must be greater than 0."

        self.group_size = group_size

        # If symmetry is True, use absmax to compute scales.
        # If symmetry is False, use minmax to compute scales and zero-points.
        self.symmetry = symmetry
        self.observer = 'absmax' if symmetry else 'minmax'

    def calculate_qparams(self, weight: torch.Tensor) -> QParams:
        """Calculate the quantization parameters for the given weight tensor.

        Args:
            weight (torch.Tensor): The weight tensor with shape
                (out_features, in_features).

        Returns:
            QParams: A namedtuple containing 'scales' and 'zero_points'.
        """

        cal_func = self.CAL_FUNC_MAP[self.granularity][self.observer]
        if self.granularity == 'per_group':
            return cal_func(weight, self.bits, self.group_size)
        else:
            return cal_func(weight, self.bits)

    def quant(self,
              weight: torch.Tensor,
              qparams: Optional[QParams] = None,
              real: bool = False) -> torch.Tensor:
        """Perform fake quantization on the given weight tensor.

        Args:
            weight (torch.Tensor): The weight tensor with shape
                (out_features, in_features).
            qparams (Optional[QParams]): A namedtuple containing 'scales'
                and 'zero_points'.
            real (bool): If True, return the integer-quantized tensor
                instead of the fake-quantized (dequantized) tensor.

        Returns:
            torch.Tensor: The fake quantized weight tensor, or the integer
                quantized tensor if `real` is True.
        """

        if qparams is None:
            qparams = self.calculate_qparams(weight)

        scales = qparams.scales
        zero_points = qparams.zero_points

        out_c, in_c = weight.shape

        # Reshape the weights if using per_group quantization
        # per tensor scales shape: [1]
        # per channel scales shape: [out_c, 1]
        # per group scales shape: [out_c, in_c//group_size, 1]
        if len(scales.shape) > 2:
            # scales shape: [out_c, in_c//group_size, 1]
            weight = weight.reshape(out_c, scales.shape[1], -1)

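        # Symmetric (absmax): q = round(w / s); dequantize as w_hat = q * s.
        # Asymmetric (minmax): q = round(w / s) + z; w_hat = (q - z) * s.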
        if zero_points is None:
            assert self.symmetry
            real_qweight = (weight / scales).round()
            fake_qweight = real_qweight * scales

        else:
            assert not self.symmetry

            real_qweight = (weight / scales).round() + zero_points
            fake_qweight = (real_qweight - zero_points) * scales

        if len(scales.shape) > 2:
            real_qweight = real_qweight.reshape(out_c, in_c)
            fake_qweight = fake_qweight.reshape(out_c, in_c)

        if real:
            return real_qweight.to(torch.int32)
        else:
            return fake_qweight
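

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original
    # module). Hypothetical settings: 4-bit, asymmetric, per-group
    # quantization with 64 input channels per group.
    torch.manual_seed(0)
    weight = torch.randn(128, 256)

    quantizer = WeightQuantizer(bits=4,
                                symmetry=False,
                                granularity='per_group',
                                group_size=64)

    # Compute scales/zero-points once, then reuse them for both the
    # fake-quantized (dequantized) and integer-quantized outputs.
    qparams = quantizer.calculate_qparams(weight)
    fake_qweight = quantizer.quant(weight, qparams)
    int_qweight = quantizer.quant(weight, qparams, real=True)

    print('max abs quantization error:',
          (weight - fake_qweight).abs().max().item())
    print('integer weight dtype:', int_qweight.dtype)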