# base.py
import functools
import gc
from collections import defaultdict

import torch
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig

from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.utils.calib_data import get_calib_dataset
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
16

Casper's avatar
Casper committed
17
class BaseAWQForCausalLM:
18
19
20
21
22
    def __init__(self, model, model_type, is_quantized):
        self.model = model
        self.model_type = model_type
        self.is_quantized = is_quantized

Casper Hansen's avatar
Casper Hansen committed
23
    @torch.no_grad()
Casper Hansen's avatar
Casper Hansen committed
24
25
    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=False, run_quant=True,
Casper Hansen's avatar
Casper Hansen committed
26
                       calib_data="pileval"):
27
28
        search_result = None

Casper Hansen's avatar
Casper Hansen committed
29
        if run_search:
30
            search_result = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
Casper Hansen's avatar
Casper Hansen committed
31
32
33
                       auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
        
        if run_quant:
Casper Hansen's avatar
Casper Hansen committed
34
            self._awq_quant(model, w_bit, q_config)
35
36
        
        return search_result
Casper Hansen's avatar
Casper Hansen committed
37
38
    
    
Casper Hansen's avatar
Casper Hansen committed
39
    def _awq_quant(self, model, w_bit, q_config):
        """Replace every named linear of every model layer with a ``WQLinear``.

        For each layer returned by ``self.get_model_layers(model)`` the
        activation is wrapped via ``_scale_activations``; then each linear is
        moved to CUDA, its weights are pseudo-quantized in place, a
        ``WQLinear`` is built from it, the original module is moved back to
        the CPU, and the replacement is installed on the layer's device.

        Args:
            model: Model whose layers are quantized in place.
            w_bit: Weight bit-width.
            q_config: Quantization config; must have a truthy ``"zero_point"``
                entry and a ``"q_group_size"`` entry, and is also expanded as
                keyword arguments to ``pseudo_quantize_tensor``.
        """
        assert q_config["zero_point"], "We only support zero_point quantization now."
        blocks = self.get_model_layers(model)

        # Run AWQ quantization
        for block in tqdm(blocks, desc="AWQ Quantization"):
            # Collect the linears before the activation is (possibly) wrapped.
            linears = get_named_linears(block)
            self._scale_activations(self, block)

            for linear_name, linear in linears.items():
                linear.cuda()
                quantized = pseudo_quantize_tensor(
                    linear.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                linear.weight.data, scales, zeros = quantized
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

                replacement = WQLinear.from_linear(
                    linear, w_bit, q_config['q_group_size'], False, scales, zeros)
                linear.cpu()

                target_device = next(block.parameters()).device
                replacement.to(target_device)
                set_op_by_name(block, linear_name, replacement)

                # Free memory after every replaced linear.
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()
    
    def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
        """Search per-layer AWQ scaling and clipping parameters on calibration data.

        The input to the first layer is captured by running the calibration
        batch through ``model`` with layer 0 temporarily wrapped in a module
        that records its arguments and aborts the forward pass. Each layer is
        then moved to CUDA in turn, the inputs of its linears are recorded
        with forward hooks, and the scale / clip searches are run and applied
        before the layer is moved back to the CPU.

        Args:
            model: Model whose layers are obtained via ``self.get_model_layers``.
            tokenizer: Forwarded to ``get_calib_dataset``.
            w_bit: Weight bit-width forwarded to the search helpers.
            q_config: Quantization config forwarded to the search helpers.
            n_samples: Forwarded to ``get_calib_dataset``.
            seqlen: Forwarded to ``get_calib_dataset`` as ``block_size``.
            auto_scale: If true, run ``auto_scale_block`` and apply its result.
            mse_range: If true, run ``auto_clip_block`` and apply its result.
            calib_data: Forwarded to ``get_calib_dataset`` as ``data``.

        Returns:
            dict with keys ``"scale"`` and ``"clip"``; each is a list of the
            per-layer search results, prefixed with the owning layer's op name.
        """
        layers = self.get_model_layers(model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(model, "cuda")

        class _EarlyExit(Exception):
            """Raised by ``Catcher`` to stop the forward pass after layer 0's input is recorded."""

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps.append(inp)
                layer_kwargs.update(kwargs)
                # A private exception type: a genuine ValueError raised by the
                # model must not be mistaken for this deliberate early exit.
                raise _EarlyExit

        # patch layer 0 to catch input and kwargs
        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            model(samples.to(next(model.parameters()).device))
        except _EarlyExit:  # expected: forward was cut short on purpose
            pass
        finally:
            # Always restore, so a failing forward cannot leave the model patched.
            layers[0] = first_layer
        del samples
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # Forward hook: keep a detached CPU copy of the first positional input.
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                linear.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name, linear in named_linears.items()
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            try:
                # get output as next layer's input
                inps = layer(inps, **layer_kwargs)[0]
            finally:
                # Remove the hooks even if the forward pass fails.
                for h in handles:
                    h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer, layer_kwargs,
                    w_bit=w_bit, q_config=q_config,
                    input_feat=input_feat,
                )
                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(layer,
                                w_bit=w_bit, q_config=q_config,
                                input_feat=input_feat,)
                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results
Casper's avatar
Casper committed
171
172
173
174

    def save_quantized():
        pass

175
176
177
178
179
180
181
182
183
184
185
186
187
    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
        """Load an unquantized checkpoint and wrap it in this class.

        The model is first instantiated with empty weights from its config,
        then the checkpoint at ``model_path`` is loaded and dispatched with
        ``device_map="balanced"``, never splitting modules of class
        ``cls.layer_type``.

        Args:
            model_path: Location passed to ``AutoConfig.from_pretrained`` and
                to ``load_checkpoint_and_dispatch``.
            model_type: Stored on the returned wrapper.
            torch_dtype: Dtype used when instantiating the model from config.
            trust_remote_code: Forwarded to both the config and model loaders.

        Returns:
            An instance of ``cls`` with ``is_quantized=False``.
        """
        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Load empty weights; honor the caller's dtype and trust setting
        # rather than hard-coded values.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Load model weights
        model = load_checkpoint_and_dispatch(
            model, model_path, device_map="balanced", no_split_module_classes=[cls.layer_type])

        return cls(model, model_type, is_quantized=False)
Casper's avatar
Casper committed
188

189
    @classmethod
    def from_quantized(cls, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
        """Build the quantized module structure and load quantized weights into it.

        The model is instantiated with empty weights from its config, every
        named linear in every layer is replaced by an initialization-only
        ``WQLinear``, and the checkpoint at ``quant_path`` is then loaded and
        dispatched with ``device_map="balanced"``.

        Args:
            model_path: Location passed to ``AutoConfig.from_pretrained``.
            quant_path: Checkpoint passed to ``load_checkpoint_and_dispatch``.
            w_bit: Weight bit-width used to construct each ``WQLinear``.
            q_config: Quantization config; must have a truthy ``"zero_point"``
                entry and a ``"q_group_size"`` entry.
            device: Currently unused; placement is decided by
                ``device_map="balanced"``.
            trust_remote_code: Forwarded to both the config and model loaders.

        Returns:
            The dispatched model itself (not wrapped in ``cls``).
        """
        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Honor the caller's trust setting rather than a hard-coded True.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)

        # Initialize layers
        assert q_config["zero_point"], "We only support zero_point quantization now."
        layers = cls.get_model_layers(model)
        for layer in tqdm(layers, desc="Replacing layers..."):
            # Collect the linears before the activation is (possibly) wrapped.
            named_linears = get_named_linears(layer)
            # ``_scale_activations`` is a staticmethod taking the owner
            # explicitly; here the owner is the class itself.
            cls._scale_activations(cls, layer)

            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

        model.tie_weights()

        model = load_checkpoint_and_dispatch(
            model, quant_path, device_map="balanced", no_split_module_classes=[cls.layer_type])

        return model
    
220
    @staticmethod
221
222
223
224
225
226
227
228
229
230
231
232
233
    def _scale_activations(self, layer):
        act_function = self.get_act_from_layer(layer)

        if act_function is not None and not isinstance(act_function, ScaledActivation):
            param = next(layer.parameters())

            # get activation scale
            scale_dict = self.get_act_for_scaling(layer)
            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

            # scale activation
            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)