"vscode:/vscode.git/clone" did not exist on "2a413829f42b8e8433a3e7cfd91cc9cb241cfbc0"
quantizer.py 2.34 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
2
3
4
import torch
from tqdm import tqdm
# Name fragments associated with embedding modules. Not referenced in the
# visible part of this file -- presumably matched as substrings of module
# names by callers; TODO confirm.
EMBEDDING_KEYWORDS = ["embed"]
# Name fragments associated with LM-head / output-projection modules (same
# caveat: usage is not visible here -- verify against callers).
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
5
    
Ji Lin's avatar
Ji Lin committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    """Simulate ``n_bit`` quantization of ``w`` (quantize, then dequantize).

    The result has the same shape and floating-point dtype as the input, but
    every value is snapped to the nearest representable quantization level.

    Args:
        w: Weight tensor. When ``q_group_size > 0`` it is reshaped to
            ``(-1, q_group_size)`` so each row is one quantization group;
            otherwise it must already be 2-D and each row is quantized
            independently.
        n_bit: Number of bits of the simulated integer format.
        zero_point: If true, use asymmetric quantization over
            ``[0, 2**n_bit - 1]`` with a per-row zero point. If false, use
            symmetric quantization over
            ``[-2**(n_bit-1), 2**(n_bit-1) - 1]`` with a zero point of 0.
        q_group_size: Group size along the last dimension; values ``<= 0``
            disable grouping.
        inplace: If true, the arithmetic is done with in-place tensor ops on
            the (possibly reshaped) tensor. ``reshape`` returns a view only
            when the memory layout allows it, so the caller's tensor is
            modified only in that case -- always use the return value.
        get_scale_zp: If true, also return the scales and zero points.

    Returns:
        The fake-quantized tensor in the original shape, or the tuple
        ``(w, scales, zeros)`` when ``get_scale_zp`` is true, where
        ``scales`` and ``zeros`` are viewed as ``(w.shape[0], -1)``.

    Raises:
        AssertionError: If the last dimension is not divisible by
            ``q_group_size``, the working tensor is not 2-D, or NaNs are
            found in the scales or weights.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        # Clamp the range so constant rows do not produce a zero scale.
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:
        # Symmetric path. The original code asserted on ``min_val`` before it
        # was ever assigned (always raising) and used a Python int for
        # ``zeros``, which broke ``get_scale_zp=True``.
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        scales = max_val / max_int
        # Keep zeros as a tensor so both branches return the same types.
        zeros = torch.zeros_like(scales)

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        # Same formula as below, expressed as a chain of in-place ops.
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        # quantize -> clamp to the integer range -> dequantize
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w

@torch.no_grad()
def pseudo_quantize_model_weight(
    model, w_bit, q_config,
):
    """Replace every linear weight in ``model`` with its fake-quantized value.

    The blocks are obtained with ``get_blocks(model)`` and the linear modules
    of each block with ``get_named_linears(block)``. Each module is moved to
    the GPU, its ``weight.data`` is overwritten with the output of
    ``pseudo_quantize_tensor``, and the module is moved back to the CPU, so
    only one module needs to be resident on the GPU at a time.

    Args:
        model: Model whose blocks are understood by ``get_blocks``.
        w_bit: Bit width, forwarded as ``n_bit``.
        q_config: Mapping of extra keyword arguments forwarded to
            ``pseudo_quantize_tensor`` (presumably ``zero_point`` and
            ``q_group_size`` -- verify against callers).

    Returns:
        None. ``model`` is modified in place.
    """
    from .pre_quant import get_blocks, get_named_linears
    layers = get_blocks(model)
    for layer in tqdm(layers, desc="pseudo weight quantization..."):
        # Only the modules are needed; their names are not used.
        for m in get_named_linears(layer).values():
            m.cuda()
            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
            m.cpu()
Ji Lin's avatar
Ji Lin committed
63
64
65


@torch.no_grad()
def real_quantize_model_weight(model, awq_model):
    """Walk the layers of ``model`` under a progress bar.

    The layers come from ``awq_model.get_model_layers(model)``. Each layer is
    bound to a local name and immediately released again; no weight is read
    or modified.

    NOTE(review): the body performs no quantization -- this looks like a stub
    left by an in-progress refactor; confirm before relying on it.

    Args:
        model: Model handed to ``awq_model.get_model_layers``.
        awq_model: Object providing ``get_model_layers(model)``, which must
            return a sized, indexable collection of layers.

    Returns:
        None.
    """
    layers = awq_model.get_model_layers(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..."):
        layer = layers[i]
        # Drop the local reference right away; nothing else is done with it.
        del layer