cal_qparams.py 3.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) OpenMMLab. All rights reserved.
from typing import NamedTuple, Optional

import torch


class QParams(NamedTuple):
    """Quantization parameters: scale factors plus optional zero points."""

    # Scale factor(s) mapping real values to quantized integers; shape follows
    # the granularity of the producing function (scalar for per-tensor,
    # keepdim-reduced tensors for per-channel / per-group).
    scales: torch.Tensor
    # Integer offsets produced by the asymmetric (min/max) schemes;
    # None for the symmetric (absmax) schemes in this module.
    zero_points: Optional[torch.Tensor]


@torch.no_grad()
def cal_qparams_per_channel_absmax(w: torch.Tensor, n_bits: int) -> QParams:
    """Compute symmetric per-channel quantization scales from the
    absolute-max value along the last dimension.

    Args:
        w: Weight tensor; the last dimension is reduced per channel.
        n_bits: Bit width of the signed quantized representation.

    Returns:
        QParams with keepdim per-channel ``scales`` and ``zero_points=None``
        (symmetric quantization needs no zero point).
    """
    # Largest representable positive value of a signed n-bit integer.
    q_max = 2**(n_bits - 1) - 1
    abs_max = w.abs().amax(dim=-1, keepdim=True)
    # Floor tiny magnitudes at 1e-5 so the scale never divides by ~zero later.
    scales = abs_max.clamp(min=1e-5) / q_max
    return QParams(scales=scales, zero_points=None)


@torch.no_grad()
def cal_qparams_per_channel_minmax(w: torch.Tensor, n_bits: int) -> QParams:
    """Compute asymmetric per-channel quantization parameters from the
    min/max range along the last dimension.

    Args:
        w: Weight tensor; the last dimension is reduced per channel.
        n_bits: Bit width of the unsigned quantized representation.

    Returns:
        QParams with keepdim per-channel ``scales`` and ``zero_points``.
    """
    # Full representable range of an unsigned n-bit integer.
    q_max = 2**n_bits - 1
    w_min = w.min(dim=-1, keepdim=True)[0]
    w_max = w.max(dim=-1, keepdim=True)[0]
    # Floor the per-channel spread at 1e-5 to avoid division by zero.
    scales = (w_max - w_min).clamp(min=1e-5) / q_max
    zero_points = torch.round(-w_min / scales)
    return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_group_absmax(w: torch.Tensor, n_bits: int,
                                 group_size: int) -> QParams:
    """Compute symmetric quantization scales per contiguous group of
    ``group_size`` input channels, from each group's absolute-max value.

    Args:
        w: 2-D weight tensor of shape (out_channels, in_channels).
        n_bits: Bit width of the signed quantized representation.
        group_size: Number of input channels sharing one scale.

    Returns:
        QParams with ``scales`` of shape
        (out_channels, in_channels // group_size, 1) and ``zero_points=None``.
    """
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    grouped = w.abs().reshape(outc, -1, group_size)
    # Largest representable positive value of a signed n-bit integer.
    q_max = 2**(n_bits - 1) - 1
    # Floor tiny magnitudes at 1e-5 so the scale never divides by ~zero later.
    scales = grouped.amax(dim=-1, keepdim=True).clamp(min=1e-5) / q_max
    return QParams(scales=scales, zero_points=None)


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor, n_bits: int,
                                 group_size: int) -> QParams:
    """Compute asymmetric quantization parameters per contiguous group of
    ``group_size`` input channels, from each group's min/max range.

    Args:
        w: 2-D weight tensor of shape (out_channels, in_channels).
        n_bits: Bit width of the unsigned quantized representation.
        group_size: Number of input channels sharing one scale/zero-point.

    Returns:
        QParams with ``scales`` and ``zero_points`` of shape
        (out_channels, in_channels // group_size, 1).
    """
    outc, inc = w.shape
    assert inc >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert inc % group_size == 0, \
        'Input channels should be divisible by group_size.'
    grouped = w.reshape(outc, -1, group_size)
    group_min = grouped.min(dim=-1, keepdim=True)[0]
    group_max = grouped.max(dim=-1, keepdim=True)[0]
    # Full representable range of an unsigned n-bit integer.
    q_max = 2**n_bits - 1
    # Floor the per-group spread at 1e-5 to avoid division by zero.
    scales = (group_max - group_min).clamp(min=1e-5) / q_max
    zero_points = torch.round(-group_min / scales)
    return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_tensor_minmax(w: torch.Tensor, n_bits: int) -> QParams:
    """Compute one pair of asymmetric quantization parameters for the whole
    tensor, from its global min/max range.

    Args:
        w: Weight tensor of any shape.
        n_bits: Bit width of the unsigned quantized representation.

    Returns:
        QParams with scalar (0-dim) ``scales`` and ``zero_points``.
    """
    # Full representable range of an unsigned n-bit integer.
    q_max = 2**n_bits - 1
    w_min, w_max = w.min(), w.max()
    # Floor the spread at 1e-5 to avoid division by zero.
    scales = (w_max - w_min).clamp(min=1e-5) / q_max
    zero_points = torch.round(-w_min / scales)
    return QParams(scales=scales, zero_points=zero_points)


@torch.no_grad()
def cal_qparams_per_tensor_absmax(w: torch.Tensor, n_bits: int) -> QParams:
    """Compute one symmetric quantization scale for the whole tensor, from
    its global absolute-max value.

    Args:
        w: Weight tensor of any shape.
        n_bits: Bit width of the signed quantized representation.

    Returns:
        QParams with a scalar (0-dim) ``scales`` and ``zero_points=None``
        (symmetric quantization needs no zero point).
    """
    # Largest representable positive value of a signed n-bit integer.
    q_max = 2**(n_bits - 1) - 1
    # Floor tiny magnitudes at 1e-5 so the scale never divides by ~zero later.
    scales = w.abs().max().clamp(min=1e-5) / q_max
    return QParams(scales=scales, zero_points=None)