quantizer.py 16.8 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
import torch
2
import inspect
Casper's avatar
Casper committed
3
4
5
6
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
Vik Paruchuri's avatar
Vik Paruchuri committed
7
from typing import Dict, List
Casper's avatar
Casper committed
8
9
10
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
Casper Hansen's avatar
Casper Hansen committed
11
from awq.quantize.scale import apply_scale, apply_clip
Casper's avatar
Casper committed
12
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
13
14
15
16
17
18
19
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize
)
Casper's avatar
Casper committed
20
21
22


class AwqQuantizer:
    """Layer-by-layer AWQ quantizer.

    Construction loads the calibration data and captures the hidden states and
    kwargs that reach the first layer of ``model`` (see ``init_quant``).
    ``quantize`` then walks the layers returned by
    ``awq_model.get_model_layers``. For each layer it:

    1. searches and applies per-channel scales;
    2. searches and applies weight clipping thresholds;
    3. replaces the layer's linear modules with ``WQLinear_GEMM`` /
       ``WQLinear_GEMV`` modules.
    """

    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
                       calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
        """
        Args:
            awq_model: Model-specific adapter. This class calls its
                ``get_model_layers``, ``get_layers_for_scaling`` and
                ``move_embed`` methods, and reads its ``model_type``.
            model: The model whose layers are quantized in place.
            tokenizer: Tokenizer forwarded to ``get_calib_dataset``.
            w_bit: Bits per quantized weight (values map to ``[0, 2**w_bit - 1]``).
            group_size: Number of consecutive elements along the last weight
                dimension that share one scale/zero point. A value <= 0 means
                one group per full row.
            version: Quantized linear layout, ``"GEMM"`` or ``"GEMV"``.
            calib_data, split, text_column: Forwarded to ``get_calib_dataset``
                to select the calibration samples.
            duo_scaling: If True, the scale search also divides by the weight
                statistic ``w_max`` (see ``_compute_best_scale``).
            modules_to_not_convert: Forwarded to ``exclude_layers_to_not_quantize``
                to filter linear layers out of quantization. Defaults to none.

        Note: construction runs ``init_quant``, which tokenizes the calibration
        data and performs a (truncated) forward pass of ``model``.
        """
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
        """Fake-quantize ``w`` with asymmetric (zero-point) integer quantization.

        Elements are grouped along the last dimension and rounded to
        ``self.w_bit``-bit unsigned integers, then immediately mapped back to
        floating point. The returned tensor therefore has the shape of ``w``.

        Args:
            w: Tensor to quantize. When ``self.group_size > 0`` its last
                dimension must be divisible by the group size. Otherwise every
                row along the last dimension is one group; any rank is accepted.
            get_scale_zp: Also return the per-group scales and zero points.

        Returns:
            The fake-quantized tensor, or ``(w_q, scales, zeros)`` with
            ``scales``/``zeros`` viewed as ``(w.shape[0], -1)``.
        """
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        else:
            # Per-channel: one group per row along the last dimension. Flattening
            # the leading dims lets callers pass the 4-D layout used by the
            # clip search, not only 2-D weights.
            w = w.reshape(-1, org_w_shape[-1])
        assert w.dim() == 2

        # zero point quantization
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** self.w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(w).sum() == 0

        w = w.reshape(org_w_shape)

        if get_scale_zp:
            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        return w

    def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: torch.Tensor):
        """Dequantize the integer weights held in ``w.weight`` as ``(q - zeros) * scales``.

        Args:
            w: Linear layer whose ``weight`` holds quantized integer values.
            scales, zeros: Per-group parameters shaped ``(out_features, n_groups)``,
                as returned by ``pseudo_quantize_tensor(..., get_scale_zp=True)``.

        Returns:
            The dequantized weight tensor, shaped like ``w.weight``.
        """
        weight = w.weight.data
        # Number of consecutive input columns covered by each group.
        repeat_count = weight.shape[-1] // zeros.shape[-1]

        # repeat_interleave expands group g over columns [g*k, (g+1)*k).
        # Plain `repeat` would tile the groups (column j -> group j % n_groups),
        # which is wrong whenever there is more than one group.
        zeros = zeros.repeat_interleave(repeat_count, dim=1).reshape(weight.shape)
        scales = scales.repeat_interleave(repeat_count, dim=1).reshape(weight.shape)

        return (weight - zeros) * scales

    def quantize(self):
        """Quantize every layer in ``self.modules`` in place.

        For each layer:
        1. Move the layer (if on CPU) and the cached inputs/kwargs to a GPU.
        2. Record the inputs of its linear layers; this also advances
           ``self.inps`` to the layer's output.
        3. Search and apply scales.
        4. Search and apply clipping.
        5. Swap the linear layers for quantized modules.
        """
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                # CPU-resident layers are spread round-robin over visible GPUs.
                self.modules[i] = self.modules[i].cuda("cuda:" + str(i % torch.cuda.device_count()))
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs["position_ids"].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs["attention_mask"].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)

            # [STEP 4]: Quantize weights
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
        """Replace every linear in ``named_linears`` with a quantized module.

        Each weight is fake-quantized on the GPU in half precision and packed
        through the ``from_linear`` constructor of the class selected by
        ``self.version``. The new module is placed on ``module``'s device and
        swapped in via ``set_op_by_name``.

        Raises:
            ValueError: if ``self.version`` is neither ``"GEMM"`` nor ``"GEMV"``.
        """
        # Resolve the target class up front; an unknown version previously
        # surfaced as an UnboundLocalError inside the loop.
        if self.version == 'GEMM':
            q_linear_module = WQLinear_GEMM
        elif self.version == 'GEMV':
            q_linear_module = WQLinear_GEMV
        else:
            raise ValueError(
                f"Unsupported quantization version {self.version!r}; expected 'GEMM' or 'GEMV'."
            )

        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.cuda().half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data,
                get_scale_zp=True
            )

            if self.version == 'GEMM':
                # The GEMM layout receives transposed scales/zeros.
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs=None):
        """Search the per-channel input scales for ``layers``.

        Args:
            module: The enclosing layer, used to resolve module names.
            prev_op: Module preceding ``layers``; its name is returned.
            layers: Linear layers whose input channels share the scales.
            inp: Calibration input fed to ``module2inspect``.
            module2inspect: Module whose output measures the error. Defaults
                to the single entry of ``layers``.
            kwargs: Extra keyword arguments for ``module2inspect``. A shallow
                copy is used, so the caller's dict is never mutated.

        Returns:
            ``(prev_op_name, layer_names, best_scales)``.
        """
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # Copy before popping so a dict the caller reuses (and the former
        # mutable default) is never modified as a side effect.
        kwargs = {} if kwargs is None else dict(kwargs)
        kwargs.pop("use_cache", None)

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # A non-positive group size means one group per full row.
        group_size = self.group_size if self.group_size > 0 else org_shape[-1]
        weight = weight.view(-1, group_size)
        # Magnitude of each weight relative to the largest one in its group.
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module (reference output)
        module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
        fp16_output = module2inspect(inp, **module_kwargs)
        if isinstance(fp16_output, tuple):
            fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect,
            layers, fp16_output, module_kwargs
        )

        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

    @torch.no_grad()
    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
                                  fp16_output, kwargs=None, n_grid=20):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X

        Candidate scales are built for ratios ``0, 1/n_grid, ..., (n_grid-1)/n_grid``.
        With ``duo_scaling`` they are ``x_max**r / w_max**(1-r)``; otherwise
        ``x_max**r``. Each candidate is normalized by ``sqrt(max * min)``.
        ``module2inspect`` is restored to its original weights after every trial.

        Returns:
            The best scales as a CPU tensor.

        Raises:
            RuntimeError: if no candidate produced a loss below ``inf``
                (e.g. every loss is NaN).
        """
        kwargs = {} if kwargs is None else kwargs
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float('inf')

        # Snapshot the original weights. copy=True is required: for a module
        # already on the CPU, `.cpu()` returns the very storage that
        # `fc.weight.mul_` modifies below, which would corrupt the snapshot.
        org_sd = {k: v.to("cpu", copy=True) for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)

        for grid_idx in range(n_grid):
            # create new scales
            ratio = grid_idx / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s) / s: dividing back by s fuses s^-1 into the weights, so the
            # forward below computes Q(W * s)(s^-1 * X).
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error (L2 norm)
            loss = (fp16_output - int_w_output).float().pow(2).mean().item()  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise RuntimeError(f"AWQ scale search found no finite loss; losses per ratio: {history}")

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        """Search clipping thresholds for the linears in ``named_linears``.

        Linears whose name contains any of ``q_``, ``k_``, ``query``, ``key``
        or ``Wqkv`` are skipped. Each other linear is moved to the GPU for the
        search and back to the CPU afterwards.

        Returns:
            List of ``(name, max_val)`` pairs from ``_compute_best_clip``.
        """
        # due to qk bmm, it is hard to clip precisely
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
        clip_list = []

        for name, linear in named_linears.items():
            if any(token in name for token in avoid_clipping):
                continue

            linear.cuda()
            max_val = self._compute_best_clip(linear.weight, input_feat[name])
            clip_list.append((name, max_val))
            linear.cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
        """Grid-search per-group clipping thresholds for a 2-D weight ``w``.

        For each group, thresholds ``max|w| * (1 - i/n_grid)`` are tried for
        ``i < max_shrink * n_grid``. The threshold minimizing the mean squared
        error between ``input_feat @ w`` and ``input_feat @ Q(clip(w))`` is kept.
        Tokens are subsampled with a stride of ``n_token // n_sample_token``
        (at least 1).

        Returns:
            Best thresholds shaped ``(co, n_group, 1)``.
        """
        assert w.dim() == 2
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else w.shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # Floor the stride at 1: with fewer than n_sample_token tokens the
        # integer division is 0, and a zero slice step raises ValueError.
        token_stride = max(1, input_feat.shape[1] // n_sample_token)
        input_feat = input_feat[:, 0::token_stride]
        w = w.reshape(w.shape[0], 1, -1, group_size)
        input_feat = input_feat.to(w.device)

        # Output channels are processed in batches to bound memory (prevent OOM).
        # The last batch may be smaller when co is not a multiple of the batch size.
        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64
        best_max_val_all = []

        for start in range(0, w.shape[0], oc_batch_size):
            w_batch = w[start:start + oc_batch_size]

            org_max_val = w_batch.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            org_out = (input_feat * w_batch).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w_batch, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)
            del org_out

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
        """Capture the hidden states and kwargs passed to the first layer.

        Layer 0 is temporarily wrapped in a ``Catcher`` that records its
        inputs and aborts the forward pass by raising ``ValueError``.

        Returns:
            ``(modules, layer_kwargs, inps)``:
            - ``modules``: the model's layers;
            - ``layer_kwargs``: the first layer's kwargs, refreshed through
              ``prepare_inputs_for_generation`` and without ``input_ids``;
            - ``inps``: the hidden states fed to the first layer.

        Raises:
            RuntimeError: if the forward pass never reached the first layer.
        """
        modules = self.awq_model.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
            split=self.split, text_column=self.text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        modules[0] = modules[0].cuda()
        self.awq_model.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        early_exit = None
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError as err:  # work with early exit
            early_exit = err

        if not inps:
            # The ValueError did not come from Catcher: layer 0 was never reached.
            # Restore the layer and surface the real cause instead of an
            # IndexError on `inps[0]` further down.
            modules[0] = modules[0].module
            raise RuntimeError("Failed to capture the inputs of the first model layer.") from early_exit

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids", None)

        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        """Run ``layer`` once on ``self.inps`` and record its linears' inputs.

        Forward hooks store the first positional input of every module in
        ``named_linears``, detached and moved to the CPU. As a side effect,
        ``self.inps`` is replaced by the layer's first output so it feeds the
        next layer. The hooks are always removed, even if the forward pass raises.

        Returns:
            Dict mapping each hooked module name to its inputs concatenated on dim 0.
        """
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}

        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu

        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        handles = [
            named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
            for name in named_linears
        ]
        try:
            # get output as next layer's input
            self.inps = layer(self.inps, **module_kwargs)[0]
        finally:
            for h in handles:
                h.remove()

        # now solve for scaling and clipping
        return {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.

        Returns:
            A new dict containing only the keys named in ``module.forward``'s signature.
        """
        module_signature = inspect.signature(module.forward).parameters
        return {k: v for k, v in inputs_kwargs.items() if k in module_signature}