auto_scale.py 8.68 KB
Newer Older
Abhinav Kulkarni's avatar
Abhinav Kulkarni committed
1
import gc
Ji Lin's avatar
Ji Lin committed
2
3
import torch
import torch.nn as nn
4
import logging
Ji Lin's avatar
Ji Lin committed
5

6
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
Ji Lin's avatar
Ji Lin committed
7
8
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
EC2 Default User's avatar
EC2 Default User committed
9
from transformers.activations import NewGELUActivation
10
from awq.modules.act import ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
11
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
Ji Lin's avatar
Ji Lin committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    """Return the column-wise mean of group-normalised weight magnitudes.

    Args:
        weight: Weight tensor. When ``q_group_size > 0`` it is viewed as
            ``(-1, q_group_size)``, so its element count must be divisible
            by the group size; otherwise it is used as given and reduced
            along ``dim=1``.
        q_group_size: Group width used for normalisation; non-positive
            values disable the regrouping.

    Returns:
        ``|weight|`` divided by the maximum magnitude of its group (or row),
        restored to the original shape and averaged over dimension 0.
    """
    original_shape = weight.shape
    grouped = weight.view(-1, q_group_size) if q_group_size > 0 else weight
    magnitude = grouped.abs()
    # Largest magnitude per group (per row when no regrouping happened).
    group_peak = magnitude.amax(dim=1, keepdim=True)
    normalised = (magnitude / group_peak).view(original_shape)
    return normalised.mean(0)


@torch.no_grad()
def get_act_scale(x):
    """Return the mean absolute value of ``x`` for each last-dimension index.

    All leading dimensions are flattened into one before averaging, so the
    result has ``x.shape[-1]`` elements.
    """
    channels = x.shape[-1]
    magnitudes = x.abs()
    flattened = magnitudes.view(-1, channels)
    return flattened.mean(0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """Fold ``scales`` out of a norm layer and into the layers that follow it.

    ``ln.weight`` (and ``ln.bias`` when present) are divided by ``scales``
    in place, and every layer in ``fcs`` has its weight columns multiplied
    by ``scales`` in place, so ``fc(ln(x))`` is mathematically unchanged.

    Args:
        ln: Module exposing a ``weight`` tensor and optionally a ``bias``.
        fcs: A single layer, or a list/tuple of layers, each exposing a 2-D
            ``weight`` whose last dimension matches ``scales``.
        scales: 1-D tensor of scaling factors.

    Raises:
        AssertionError: If any parameter of ``ln`` or ``fcs`` contains NaN
            after scaling.
    """
    if not isinstance(fcs, (list, tuple)):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    # Sanity check: a zero entry in `scales` would have produced inf/NaN.
    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """Move ``scales`` from the output of ``fc1`` into the input of ``fc2``.

    The trailing ``scales.size(0)`` output rows of ``fc1`` (weight and, when
    present, bias) are divided by ``scales`` in place, and the input columns
    of ``fc2.weight`` are multiplied by ``scales`` in place.

    Only the trailing rows of ``fc1`` are touched, so ``fc1`` may have more
    outputs than ``scales`` has entries.
    NOTE(review): presumably this supports fused projections where only the
    last output slice feeds ``fc2`` -- confirm against callers.

    Args:
        fc1: ``nn.Linear`` producing the activations consumed by ``fc2``.
        fc2: ``nn.Linear`` whose ``in_features`` matches ``scales``.
        scales: 1-D tensor of scaling factors.

    Raises:
        AssertionError: If either layer is not ``nn.Linear`` or a parameter
            contains NaN after scaling.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)
    n_scaled = scales.size(0)

    fc1.weight[-n_scaled:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # Slice the bias exactly like the weight; dividing the whole bias
        # fails whenever fc1 has more outputs than there are scales.
        fc1.bias[-n_scaled:].div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1).to(fc2.weight.device))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


81
82
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """Multiply the input columns of ``fc.weight`` by ``scales`` in place.

    ``gelu`` is only type-checked here; it is not modified by this function.

    Args:
        gelu: Activation module; must be ``nn.GELU``, ``BloomGelu`` or
            ``NewGELUActivation``.
        fc: ``nn.Linear`` that consumes the activation output.
        scales: 1-D tensor of scaling factors, one per input column of ``fc``.

    Raises:
        AssertionError: If the argument types are wrong or a parameter of
            ``fc`` contains NaN after scaling.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, NewGELUActivation))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
Casper Hansen's avatar
Casper Hansen committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

def pseudo_quantize_tensor(w, w_bit=4,
                           zero_point=True,
                           q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    """Quantize ``w`` to ``w_bit`` bits and immediately de-quantize it.

    Args:
        w: Tensor to fake-quantize. It must be 2-D, unless
            ``q_group_size > 0``, in which case it is reshaped to
            ``(-1, q_group_size)`` and its last dimension must be divisible
            by the group size.
        w_bit: Number of quantization bits.
        zero_point: If true, use asymmetric quantization over
            ``[0, 2**w_bit - 1]`` with a per-row zero point. Otherwise use
            symmetric quantization over
            ``[-2**(w_bit-1), 2**(w_bit-1) - 1]`` with a zero offset.
        q_group_size: Group width; non-positive keeps the rows of ``w``.
        inplace: If true, the rounding is applied with in-place ops on the
            (possibly reshaped) tensor.
        get_scale_zp: If true, also return the scales and zero points.

    Returns:
        The de-quantized tensor in the original shape, or the tuple
        ``(tensor, scales, zeros)`` when ``get_scale_zp`` is true, with
        ``scales`` and ``zeros`` viewed as ``(w.shape[0], -1)``.

    Raises:
        AssertionError: On a shape mismatch or if NaN values appear.
    """
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:
        # Symmetric branch. The previous code asserted on an unbound
        # `min_val` here and therefore always crashed.
        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        max_int = 2 ** (w_bit - 1) - 1
        min_int = -(2 ** (w_bit - 1))
        scales = max_val / max_int
        # A tensor (not the int 0) so `zeros.view(...)` works below.
        zeros = torch.zeros_like(scales)

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    return w
135

Ji Lin's avatar
Ji Lin committed
136
@torch.no_grad()
def auto_scale_block(awq_model,
                     module,
                     module_kwargs,
                     quant_config,
                     input_feat):
    """Grid-search per-channel scales for each layer group in ``module``.

    Args:
        awq_model: Object providing
            ``get_layers_for_scaling(module, input_feat, module_kwargs)``,
            which must return dicts whose keys match the parameters of the
            inner ``_auto_get_scale`` (``prev_op``, ``layers``, ``inp`` and
            optionally ``module2inspect`` / ``kwargs``).
        module: Block whose sub-layers are scaled; also used to resolve
            operator names via ``get_op_name``.
        module_kwargs: Keyword arguments for running the block. The
            ``"use_cache"`` entry, if present, is removed in place.
        quant_config: Mapping with ``"w_bit"`` and ``"q_group_size"``. When
            ``"w_bit"`` is ``None`` the weights are not quantized during the
            search.
        input_feat: Passed through to ``get_layers_for_scaling``.

    Returns:
        A list of ``(prev_op_name, layer_names, scales)`` tuples, where
        ``layer_names`` is a tuple of names and ``scales`` is a CPU tensor.

    Raises:
        RuntimeError: If no candidate ratio yields a loss below infinity
            (for example when every loss is NaN).
    """
    # firstly, get the weight quantize function
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                w_bit=quant_config["w_bit"],
                q_group_size=quant_config["q_group_size"],
            ).detach()
    else:
        def w_quantize_func(p):
            return p

    module_kwargs.pop("use_cache", None)

    # find the best scale ratio
    def _search_module_scale(module2inspect, layers: list, inp, kwargs=None):
        """Return the scales minimising the output error of ``module2inspect``."""
        kwargs = {} if kwargs is None else kwargs
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # w: co, ci
        # x: n, ci
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        q_group_size = quant_config.get("q_group_size", -1)
        # Only regroup for a positive group size; per-channel configs
        # (q_group_size <= 0) normalise each row as it is.
        if q_group_size is not None and q_group_size > 0:
            weight = weight.view(-1, q_group_size)
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)

        # Clear GPU memory
        del weight, w_scale
        gc.collect()
        torch.cuda.empty_cache()

        inp = inp.to(next(module2inspect.parameters()).device)
        with torch.no_grad():
            org_out = module2inspect(inp, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(inp)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        # CPU snapshot used to restore the weights after every trial.
        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
        for grid_idx in range(n_grid):
            ratio = grid_idx / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)
                      ).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in layers:
                # Use the layer's own device for both the multiply and the
                # divide so they cannot disagree.
                fc_scales = scales.view(1, -1).to(fc.weight.device)
                fc.weight.mul_(fc_scales)
                fc.weight.data = w_quantize_func(fc.weight.data) / fc_scales
            out = module2inspect(inp, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            module2inspect.load_state_dict(org_sd)
        if best_ratio == -1:
            logging.debug(history)
            raise RuntimeError(
                f"scale search found no finite loss; loss history: {history}")
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs=None):
        """Search scales for one layer group and label them with op names."""
        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()

        # prev_op_name, [layer_name], scale
        return (
            get_op_name(module, prev_op),
            tuple(get_op_name(module, m) for m in layers),
            scales,
        )

    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list

def apply_scale(module, scales_list, input_feat_dict=None):
    """Apply previously searched scales to the operators inside ``module``.

    For every ``(prev_op_name, layer_names, scales)`` entry the operators
    are looked up by name, moved to CUDA, rescaled according to the type of
    ``prev_op``, and moved back to CPU. A GELU-type ``prev_op`` is replaced
    inside ``module`` by a ``ScaledActivation`` wrapper.

    Args:
        module: Container in which operator names are resolved.
        scales_list: Iterable of ``(prev_op_name, layer_names, scales)``.
        input_feat_dict: Optional mapping from layer name to a cached input
            tensor; each tensor listed in ``layer_names`` is divided by
            ``scales`` in place.

    Raises:
        NotImplementedError: If ``prev_op`` is of an unsupported type.
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        # `scales` is intentionally not moved here: Tensor.cuda() returns a
        # copy instead of acting in place, and each helper below transfers
        # the scales to the device it needs.

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)

        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)

        elif isinstance(prev_op, (nn.GELU, BloomGelu, NewGELUActivation)):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)

        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()