# Copyright (c) OpenMMLab. All rights reserved.
from typing import NamedTuple, Optional

import torch


class QParams(NamedTuple):
    """A class to hold the quantization parameters."""

    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_channel_absmax(w: torch.Tensor,
                                   n_bits: int,
                                   return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for each channel using absolute max
    value."""

    absmax = w.abs().max(dim=-1, keepdim=True)[0]
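    # Symmetric scheme: map the per-channel absmax onto the largest positive
    # value of a signed n-bit integer, 2**(n_bits - 1) - 1.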
    q_max = 2**(n_bits - 1) - 1
    scales = absmax.clamp(min=1e-5).div(q_max)

    if return_stats:
        return QParams(scales=scales, zero_points=None), absmax
    else:
        return QParams(scales=scales, zero_points=None)


@torch.no_grad()
def cal_qparams_per_channel_minmax(w: torch.Tensor,
                                   n_bits: int,
                                   return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for each channel using min and max
    values."""

    w_min = w.min(dim=-1, keepdim=True)[0]
    w_max = w.max(dim=-1, keepdim=True)[0]

    q_max = 2**n_bits - 1
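    # Asymmetric scheme: the scale maps [w_min, w_max] onto the 2**n_bits - 1
    # integer steps, and the zero point is the integer that represents 0.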
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)

    zero_points = (-w_min / scales).round()

    if return_stats:
        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
    else:
        return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_group_absmax(w: torch.Tensor,
                                 n_bits: int,
                                 group_size: int,
                                 return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for each group using absolute max
    value."""

    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
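    # View each output channel as (inc // group_size) groups and take the
    # absolute max within every group.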
    absmax = w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0]
    q_max = 2**(n_bits - 1) - 1
    scales = absmax.clamp(min=1e-5).div(q_max)
    if return_stats:
        return QParams(scales=scales, zero_points=None), absmax
    else:
        return QParams(scales=scales, zero_points=None)


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int,
                                 group_size: int,
                                 return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for each group using min and max
    values."""

    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
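    # Per-group statistics: scales and zero_points get shape
    # (outc, inc // group_size, 1) so they broadcast over each group.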
    w_group_wise = w.reshape(outc, -1, group_size)
    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]

    q_max = 2**n_bits - 1
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    zero_points = (-w_min / scales).round()
    if return_stats:
        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
    else:
        return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_tensor_minmax(w: torch.Tensor,
                                  n_bits: int,
                                  return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for the entire tensor using min and
    max values."""

    w_min = w.min()
    w_max = w.max()
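    # w.min()/w.max() are 0-dim tensors, so a single (scale, zero_point)
    # pair covers the whole tensor.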

    q_max = 2**n_bits - 1
    scales = (w_max - w_min)
    scales = scales.clamp_(min=1e-5).div_(q_max)
    zero_points = (-w_min / scales).round()
    if return_stats:
        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
    else:
        return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_tensor_absmax(w: torch.Tensor,
                                  n_bits: int,
                                  return_stats: bool = False) -> QParams:
    """Calculate quantization parameters for the entire tensor using absolute
    max value."""
    absmax = w.abs().max()
    q_max = 2**(n_bits - 1) - 1
    scales = absmax.clamp(min=1e-5).div(q_max)

    if return_stats:
        return QParams(scales=scales, zero_points=None), absmax
    else:
        return QParams(scales=scales, zero_points=None)
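

# A minimal usage sketch (added for illustration; not part of the original
# module): fake-quantize a weight matrix with 4-bit per-group minmax
# parameters and check the round-trip error. The quantize/dequantize formulas
# below (q = round(w / s) + z, w_hat = (q - z) * s) follow the conventions
# implied by the functions above; they are an assumption here, not an API
# exported by this file.
if __name__ == '__main__':
    torch.manual_seed(0)
    w = torch.randn(8, 64)
    n_bits, group_size = 4, 16

    qparams = cal_qparams_per_group_minmax(w, n_bits, group_size)
    w_grouped = w.reshape(8, -1, group_size)

    # Quantize: scale, shift by the zero point and clamp to the unsigned
    # n-bit range [0, 2**n_bits - 1].
    q = (w_grouped / qparams.scales + qparams.zero_points).round()
    q = q.clamp(0, 2**n_bits - 1)

    # Dequantize and compare against the original weights.
    w_hat = ((q - qparams.zero_points) * qparams.scales).reshape(8, 64)
    print(f'max abs error: {(w - w_hat).abs().max().item():.4f}')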