# Copyright (c) OpenMMLab. All rights reserved.
from typing import Type, TypeVar

import torch
from torch import nn


class WeightOnlyQLinear(nn.Module):
    """This class implements weight only quantization linear.

    Args:
        w_bit (int): number of bits for quantization.
        symmetry (bool): If true, use symmetric quantization,
            otherwise use asymmetric quantization.
        group_size (int): size of the quantization group. A value of -1
            means one group spanning all input features.
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): Whether to allocate a bias buffer.
            Defaults to False.
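
    Example:
        A minimal construction sketch (the packed weights still have to be
        filled in, e.g. via from_linear); the sizes are only illustrative:

        >>> layer = WeightOnlyQLinear(
        ...     w_bit=4, symmetry=False, group_size=128,
        ...     in_features=4096, out_features=4096)
        >>> layer.qweight.shape
        torch.Size([4096, 512])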
    """

    def __init__(self,
                 w_bit: int,
                 symmetry: bool,
                 group_size: int,
                 in_features: int,
                 out_features: int,
                 bias: bool = False) -> None:
        super().__init__()

        if w_bit not in [2, 4, 8]:
            raise NotImplementedError('Only 2,4,8 bit are supported for now.')

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features

        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0

        w_pack_oc = out_features // (32 // self.w_bit)
        w_inc = in_features
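        # Each int32 in qweight packs (32 // w_bit) quantized values along
        # the output-channel dimension, so the packed buffer has shape
        # (in_features, out_features // (32 // w_bit)).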
        weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
        self.register_buffer('qweight', weight)

        if bias:
            self.register_buffer('bias', torch.zeros(out_features))
        else:
            self.bias = None

        s_inc = in_features // self.group_size
        s_oc = out_features
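        # One fp16 scale per (input-channel group, output channel).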
        scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
        self.register_buffer('scales', scales)

        if not symmetry:
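            # Asymmetric quantization additionally stores per-group zero
            # points, packed like the weights: (32 // w_bit) values per
            # int32 along the output-channel dimension.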
            z_inc = in_features // self.group_size
            z_oc = out_features // (32 // self.w_bit)
            zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
            self.register_buffer('qzeros', zeros)
        else:
            self.qzeros = None

    @classmethod
    def from_linear(cls: Type['WeightOnlyQLinear'],
                    linear: nn.Linear,
                    quantizer: TypeVar('Quantizer'),
                    awq_layout: bool = True) -> 'WeightOnlyQLinear':
        """Create a WeightOnlyQLinear object from a PyTorch Linear object.

        Args:
            linear (nn.Linear): PyTorch Linear object.
            quantizer (Quantizer): Object that handles quantization.
            awq_layout (bool): AWQ layout. Defaults to True.

        Returns:
            WeightOnlyQLinear: A WeightOnlyQLinear object.
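
        Example:
            A rough usage sketch; `quantizer` stands for any object exposing
            the bits, group_size, symmetry, calculate_qparams and quant
            members used below (a 4-bit quantizer is assumed here; no
            concrete quantizer class is defined in this file):

            >>> fc = nn.Linear(4096, 4096, bias=False)
            >>> qlayer = WeightOnlyQLinear.from_linear(fc, quantizer)
            >>> qlayer.qweight.shape
            torch.Size([4096, 512])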
        """
        device = linear.weight.device

        w_bit = quantizer.bits
        pack_num = 32 // w_bit
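        # AWQ interleaves values inside each int32 (order 0,2,4,6,1,3,5,7);
        # the plain layout packs them sequentially.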
        if awq_layout:
            assert w_bit == 4
            pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
        else:
            pack_order = torch.arange(pack_num)
        group_size = quantizer.group_size
        symmetry = quantizer.symmetry

        in_features = linear.in_features
        out_features = linear.out_features
        bias = linear.bias is not None

        qlinear = cls(w_bit, symmetry, group_size, in_features, out_features,
                      bias)
        qlinear.bias = linear.bias

        # Quantize the weight to integers and transpose it to the
        # (in_features, out_features) layout used by qweight.
        qparams = quantizer.calculate_qparams(linear.weight)
        i32_w = quantizer.quant(linear.weight, qparams, real=True)
        i32_w = i32_w.t().contiguous()

        pack_int_w = torch.zeros_like(qlinear.qweight).to(device)

        # Pack pack_num adjacent output channels into each int32 column,
        # placing channel pack_order[i] in bit slot i * w_bit.
        for col in range(pack_int_w.shape[1]):
            for i in range(pack_num):
                pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
                pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)

        qlinear.qweight = pack_int_w
        qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()

        if qparams.zero_points is not None:
            # Asymmetric case: pack the per-group zero points with the same
            # bit layout used for the weights above.
            zeros = qparams.zero_points.to(torch.int32).to(device)
            zeros = zeros.squeeze(-1).t().contiguous()
            pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)

            for col in range(pack_int_zeros.shape[1]):
                for i in range(pack_num):
                    qzero_col = zeros[:, col * pack_num + pack_order[i]]
                    pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
            qlinear.qzeros = pack_int_zeros

        qlinear.to('cpu')

        return qlinear