quantizer.py
import torch
    
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, w_bit=4,
                           zero_point=True, 
                           q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
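    """Simulated (fake) quantization of a 2-D weight tensor.

    Rounds `w` onto a `w_bit`-bit integer grid, either per group of
    `q_group_size` consecutive elements along the last dimension (when
    `q_group_size > 0`) or per row, then dequantizes back to floating point.
    With `zero_point=True` the grid is asymmetric ([0, 2**w_bit - 1]);
    otherwise it is symmetric around zero.

    Returns the fake-quantized tensor; if `get_scale_zp` is True, also
    returns the scales and zero points reshaped per row of `w`.
    """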
    org_w_shape = w.shape
    if q_group_size > 0:
        # group q_group_size consecutive elements of the last dim for per-group quantization
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
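        # asymmetric quantization: map each group's [min, max] range onto [0, 2**w_bit - 1]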
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # symmetric quantization; this branch was never used in practice
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (w_bit - 1) - 1
        min_int = - 2 ** (w_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        # in-place quantize-dequantize: round(w / s) + z, clamp, undo the shift, rescale
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        # out-of-place quantize-dequantize; the result stays in floating point
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        # reshape scales/zeros so each row of w has its groups laid out along dim 1
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w
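
# Minimal usage sketch with illustrative shapes: fake-quantize a weight matrix
# to 4 bits with groups of 128 elements along the last dimension, then also
# retrieve the per-group scales and zero points.
if __name__ == "__main__":
    w = torch.randn(4096, 4096)
    w_q = pseudo_quantize_tensor(w, w_bit=4, zero_point=True, q_group_size=128)
    w_q2, scales, zeros = pseudo_quantize_tensor(
        w, w_bit=4, zero_point=True, q_group_size=128, get_scale_zp=True)
    print(w_q.shape, scales.shape, zeros.shape)  # (4096, 4096), (4096, 32), (4096, 32)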