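"""Per-group weight clipping search for quantization.

For each linear layer, grid-search a symmetric clipping threshold per weight
group so that the clipped, pseudo-quantized weights reproduce the original
layer outputs on calibration features with minimum squared error.
"""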
import torch
import torch.nn as nn
from .quantizer import pseudo_quantize_tensor
import gc

__all__ = ["auto_clip_block"]


# weight quantization
@torch.no_grad()
def auto_clip_layer(w,
                    input_feat,
                    quant_config,
                    n_grid=20,
                    max_shrink=0.5,
                    n_sample_token=512):
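    """Grid-search, per weight group, the symmetric clipping threshold that
    minimizes the squared error between the FP output and the output of the
    clipped + pseudo-quantized weights on sampled input features.

    w           : [out_features, in_features] weight of a single nn.Linear
    input_feat  : cached calibration activations for that layer
    quant_config: dict holding at least "q_group_size" (<= 0 means one group
                  spanning the whole input dim); forwarded to
                  pseudo_quantize_tensor
    Returns the best |max| clip value per group, shape [co, n_group, 1].
    """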
    assert w.dim() == 2
    # w           [co, ci]      -> [co, 1, n_group, group size]
    # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    # subsample tokens to bound the search cost; clamp the stride to >= 1 so
    # it cannot be zero when fewer than n_sample_token tokens are available
    input_feat = input_feat[:, 0::max(1, input_feat.shape[1] // n_sample_token)]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

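    # output channels are searched independently, so batching them only
    # bounds peak memory, not the result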
    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        # per-group FP reference output via broadcasting: co, n_token, n_group
        org_out = (input_feat * w).sum(dim=-1)

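        # shrink the clipping range in steps of 1 / n_grid, down to
        # (1 - max_shrink) of the original absmax, and keep the per-group
        # threshold with the lowest reconstruction error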
        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, **quant_config)
            cur_out = (input_feat * q_w).sum(dim=-1)

            # per-group MSE over sampled tokens: co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)


@torch.no_grad()
def auto_clip_block(module,
                    quant_config,
                    input_feat):
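    """Run the clipping search for every nn.Linear in `module`, skipping the
    query/key projections, and return a list of (name, best_max_val) pairs.
    """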

    named_linears = {name: m for name, m in module.named_modules()
                     if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # the Q/K projections feed the attention QK bmm, which is hard to
        # clip precisely without hurting accuracy, so skip them
        if any(k in name for k in ["q_", "k_", "query", "key", "Wqkv"]):
            continue
        named_linears[name].cuda()
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], quant_config=quant_config)
        clip_list.append((name, max_val))
        named_linears[name].cpu()
    return clip_list


@torch.no_grad()
def apply_clip(module, clip_list):
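    """Clamp each listed layer's weight in place to its searched threshold."""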
    from ..utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
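        # view as [co, n_group, group_size] so max_val broadcasts per group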
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
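

# Usage sketch (illustrative; not part of the original module): given a
# transformer block `block`, cached per-layer calibration activations
# `input_feat` (layer name -> tensor), and a quant_config whose keys match
# pseudo_quantize_tensor -- "q_group_size" is used above, while keys such as
# "n_bit" / "zero_point" are assumptions about that function -- the two
# entry points compose as:
#
#   clip_list = auto_clip_block(block, quant_config, input_feat)
#   apply_clip(block, clip_list)  # weights are clamped in place
#
# Quantizing the block afterwards with the same quant_config then benefits
# from the searched clipping ranges.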