base.py 7.15 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
2
3
4
import gc
import torch
import functools
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
5
from tqdm import tqdm
Casper Hansen's avatar
Casper Hansen committed
6
7
8
9
10
from collections import defaultdict

from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
Casper Hansen's avatar
Casper Hansen committed
11
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
12

Casper Hansen's avatar
Casper Hansen committed
13
14
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
15

Casper's avatar
Casper committed
16
class BaseAWQForCausalLM:
    """Architecture-independent driver for AWQ search and weight quantization.

    Subclasses are expected to provide the model-specific hooks this class
    calls but does not define: ``get_model_layers(model)``,
    ``move_embed(model, device)`` and ``get_act_for_scaling(layer)``.
    ``save_quantized``, ``from_pretrained`` and ``from_quantized`` are
    placeholders here.
    """

    @torch.no_grad()
    def quantize(self, model, tokenizer=None, w_bit=4, q_config=None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval", init_only=False):
        """Optionally run the AWQ search, then quantize ``model`` in place.

        Args:
            model: Model to process; its layers are modified in place.
            tokenizer: Tokenizer passed to ``get_calib_dataset`` (search only).
            w_bit: Number of bits per quantized weight.
            q_config: Quantizer options. The quantization pass reads
                ``"zero_point"`` and ``"q_group_size"`` and forwards the whole
                mapping to ``pseudo_quantize_tensor``. ``None`` means ``{}``.
            n_samples: ``n_samples`` passed to ``get_calib_dataset`` (search only).
            seqlen: ``block_size`` passed to ``get_calib_dataset`` (search only).
            auto_scale: Run ``auto_scale_block`` and apply its scales.
            mse_range: Run ``auto_clip_block`` and apply its clipping.
            run_search: Run the calibration-based search before quantizing.
            run_quant: Replace linear layers with ``WQLinear`` modules.
            calib_data: Dataset name passed to ``get_calib_dataset``.
            init_only: Build ``WQLinear`` modules without running
                ``pseudo_quantize_tensor`` on the existing weights.

        Returns:
            The search results dict (``{"scale": [...], "clip": [...]}``) when
            ``run_search`` is true, otherwise ``None``.
        """
        # A fresh dict per call; a ``{}`` default would be shared by all calls.
        if q_config is None:
            q_config = {}

        awq_results = None
        if run_search:
            awq_results = self._awq_search(
                model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant(model, w_bit, q_config, init_only)

        return awq_results

    def _awq_quant(self, model, w_bit, q_config, init_only):
        """Swap every linear of every layer for a ``WQLinear`` module.

        Before swapping, the activation named by ``get_act_for_scaling`` is
        wrapped in a ``ScaledActivation`` with all-ones scales, unless it has
        already been wrapped (e.g. by ``apply_scale`` during the search).

        Raises:
            NotImplementedError: If ``q_config["zero_point"]`` is missing or falsy.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not q_config.get("zero_point"):
            raise NotImplementedError("We only support zero_point quantization now.")
        layers = self.get_model_layers(model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)

            # Read the layer's device BEFORE any module is moved. The
            # quantization branch below moves each source linear to the CPU;
            # if that linear owns the layer's first parameter, looking the
            # device up afterwards would report the CPU and misplace the
            # new WQLinear.
            param = next(layer.parameters())
            device = param.device

            # get activation scale
            scale_dict = self.get_act_for_scaling(layer)
            scale_name = scale_dict['scale_name']
            # Only wrap once: check the exact module that would be replaced.
            if not isinstance(layer.get_submodule(scale_name), ScaledActivation):
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=device)
                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_name, scaled_act)

            for name, module in named_linears.items():
                if init_only:
                    q_linear = WQLinear.from_linear(
                        module, w_bit, q_config['q_group_size'], True)
                else:
                    module.cuda()
                    module.weight.data, scales, zeros = pseudo_quantize_tensor(
                        module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear = WQLinear.from_linear(
                        module, w_bit, q_config['q_group_size'], False, scales, zeros)
                    module.cpu()

                q_linear.to(device)
                set_op_by_name(layer, name, q_linear)

                if not init_only:
                    # Release the GPU copy of the source weights right away.
                    torch.cuda.empty_cache()
                    gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        """Search AWQ scales and clipping ranges layer by layer.

        Calibration samples are run through ``model`` until the first layer,
        capturing that layer's input and keyword arguments. Each layer is then
        moved to CUDA and run once, while hooks record the inputs of its
        linears. Depending on the flags, the searched scales and clipping are
        applied in place. The layer is then moved back to the CPU, and its
        output is used as the next layer's input.

        Returns:
            ``{"scale": [...], "clip": [...]}`` built from the
            ``auto_scale_block`` / ``auto_clip_block`` results, with each
            layer's name in ``model`` prepended via ``append_str_prefix``.

        Raises:
            RuntimeError: If the calibration forward pass never reached the
                first layer.
        """
        layers = self.get_model_layers(model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(model, "cuda")

        # Get the input and kwargs of layer 0. ``with_kwargs`` forward hooks
        # need PyTorch 2.0, so the Catcher wrapper is used instead.
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps.append(inp)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            model(samples.to(next(model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        finally:
            # Restore layer 0 even if the forward failed for another reason.
            layers[0] = layers[0].module
        del samples

        # A ValueError raised before layer 0 would also be swallowed above;
        # fail clearly instead of with an IndexError.
        if not inps:
            raise RuntimeError("Calibration forward pass did not reach the first layer.")
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Forward hook that stores (a detached CPU copy of) a linear's input.
        def cache_input_hook(m, x, y, name, feat_dict):
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for layer in tqdm(layers, desc="AWQ Search:"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)
            # Prefix that turns layer-local names into model-global ones.
            layer_prefix = get_op_name(model, layer) + "."

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                module.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name, module in named_linears.items()
            ]
            try:
                inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
                # get output as next layer's input
                inps = layer(inps, **layer_kwargs)[0]
            finally:
                # Never leave hooks attached, even if the forward raised.
                for h in handles:
                    h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer, layer_kwargs,
                    w_bit=w_bit, q_config=q_config,
                    input_feat=input_feat,
                )
                apply_scale(layer, scales_list, input_feat_dict=input_feat)
                awq_results["scale"] += append_str_prefix(scales_list, layer_prefix)

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    w_bit=w_bit, q_config=q_config,
                    input_feat=input_feat,
                )
                apply_clip(layer, clip_list)
                awq_results["clip"] += append_str_prefix(clip_list, layer_prefix)

            layer.cpu()
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, *args, **kwargs):
        """Placeholder: not implemented in this base class; returns ``None``."""
        return None

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Placeholder alternate constructor: not implemented here; returns ``None``."""
        return None

    @classmethod
    def from_quantized(cls, *args, **kwargs):
        """Placeholder alternate constructor: not implemented here; returns ``None``."""
        return None