auto_scale.py 6.84 KB
Newer Older
Abhinav Kulkarni's avatar
Abhinav Kulkarni committed
1
import gc
Ji Lin's avatar
Ji Lin committed
2
3
4
import torch
import torch.nn as nn

5
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
Ji Lin's avatar
Ji Lin committed
6
7
8
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

9
from .qmodule import ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
10
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
Ji Lin's avatar
Ji Lin committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

# Public API of this module: the per-block scale search (auto_scale_block) and
# the function that folds the searched scales back into a block (apply_scale).
__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    """Return the mean normalized magnitude of each input channel of ``weight``.

    Every element's absolute value is divided by the largest absolute value in
    its row or, when ``q_group_size > 0``, in its group of ``q_group_size``
    consecutive elements. The normalized tensor is reshaped back to the
    original shape and averaged over its first dimension.

    Args:
        weight: Weight tensor. When ``q_group_size > 0``, its element count must
            be divisible by ``q_group_size``.
        q_group_size: Quantization group size, or a non-positive value to
            normalize per row.

    Returns:
        Tensor of shape ``weight.shape[1:]``. For finite input, its values are
        in ``[0, 1]``.
    """
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    abs_weight = weight.abs()
    group_max = abs_weight.amax(dim=1, keepdim=True)
    # An all-zero row/group would otherwise produce 0/0 = NaN. That NaN poisons
    # the mean and every scale derived from it. Dividing by 1 instead gives
    # such a group the correct normalized magnitude of 0.
    group_max = group_max.masked_fill(group_max == 0, 1)
    scale = (abs_weight / group_max).view(org_shape)
    return scale.mean(0)


@torch.no_grad()
def get_act_scale(x):
    """Return the mean absolute activation of each last-dimension channel.

    All leading dimensions of ``x`` are flattened together, so the result has
    shape ``(x.shape[-1],)``.
    """
    num_channels = x.shape[-1]
    magnitudes = x.abs()
    flattened = magnitudes.view(-1, num_channels)
    return flattened.mean(dim=0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """Fold per-channel ``scales`` from a normalization layer into the linears it feeds.

    ``ln``'s weight (and bias, if it has one) is divided by ``scales``, and the
    input columns of every linear in ``fcs`` are multiplied by ``scales``. In
    exact arithmetic, ``fc(ln(x))`` is therefore unchanged. All updates are
    done in place.

    Args:
        ln: Normalization module with a ``weight`` and an optional ``bias``
            parameter (e.g. ``nn.LayerNorm``).
        fcs: A single linear layer, or a list of linear layers, fed by ``ln``.
        scales: 1-D tensor with one entry per ``ln`` channel.

    Raises:
        AssertionError: if any updated parameter contains NaN.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if getattr(ln, "bias", None) is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        # The linears are not guaranteed to share ``ln``'s device.
        fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """Fold per-channel ``scales`` between two consecutive linears ``fc1 -> fc2``.

    The last ``scales.size(0)`` output rows of ``fc1`` (weight and bias) are
    divided by ``scales``, and the input columns of ``fc2`` are multiplied by
    them. All updates are done in place. Scaling only the trailing rows lets
    ``fc1`` have more outputs than ``fc2`` has inputs. Presumably this is for a
    fused projection whose last slice feeds ``fc2`` (TODO confirm against
    callers).

    Args:
        fc1: Producer linear layer.
        fc2: Consumer linear layer with ``in_features == scales.size(0)``.
        scales: 1-D tensor of per-channel scales.

    Raises:
        AssertionError: if either module is not ``nn.Linear`` or any updated
            parameter contains NaN.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)
    num_scaled = scales.size(0)

    fc1.weight[-num_scaled:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # The bias must be scaled over the same trailing rows as the weight.
        # Dividing the whole bias fails whenever fc1 has extra leading rows.
        fc1.bias[-num_scaled:].div_(scales.view(-1))

    # fc2 is not guaranteed to share fc1's device.
    fc2.weight.mul_(scales.view(1, -1).to(fc2.weight.device))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """Multiply the input columns of ``fc`` by ``scales`` (in place).

    ``gelu`` is only type-checked; it is not modified here. The compensating
    division of the activation output has to happen elsewhere. For example,
    ``apply_scale`` wraps the activation in ``ScaledActivation``, which
    presumably divides by ``scales``; confirm in ``qmodule``.

    Raises:
        AssertionError: if ``gelu`` is not a GELU, ``fc`` is not ``nn.Linear``,
            or a parameter of ``fc`` contains NaN afterwards.
    """
    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
    assert isinstance(fc, nn.Linear)

    column_scales = scales.view(1, -1).to(fc.weight.device)
    fc.weight.mul_(column_scales)

    assert not any(torch.isnan(param).any() for param in fc.parameters())


@torch.no_grad()
def auto_scale_block(awq_model,
                     module,
                     module_kwargs,
                     quant_config,
                     input_feat):
    """Search activation-aware per-channel scales for every scalable group in ``module``.

    For each group returned by ``awq_model.get_layers_for_scaling``, 20
    candidate scales are tried:

        ``x_max ** ratio / w_max ** (1 - ratio)`` for ``ratio = 0, 0.05, ..., 0.95``

    For each candidate, the group's linear weights are scaled, pseudo-quantized
    and unscaled. The inspected module's output is then compared (MSE) against
    its original output. The candidate with the lowest error is kept.

    Args:
        awq_model: Object providing
            ``get_layers_for_scaling(module, input_feat, module_kwargs)``.
        module: The block being processed. Op names are resolved relative to it.
        module_kwargs: Keyword arguments for the forward passes. NOTE: a
            ``"use_cache"`` entry is popped from this dict in place.
        quant_config: Quantization settings. ``w_bit`` of ``None`` disables
            quantization. ``q_group_size`` (optional, default -1) is the group
            size. The whole dict is forwarded to ``pseudo_quantize_tensor``.
        input_feat: Captured input features, passed to ``get_layers_for_scaling``.

    Returns:
        List of ``(prev_op_name, layer_names, scales)`` tuples with ``scales``
        on the CPU, as consumed by ``apply_scale``.

    Raises:
        RuntimeError: if no candidate produces a finite loss.
    """
    from .quantizer import pseudo_quantize_tensor

    # firstly, get the weight quantize function
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(p, **quant_config).detach()
    else:
        def w_quantize_func(p):
            return p

    # Removed in place, before any forward pass below.
    module_kwargs.pop("use_cache", None)

    def _search_module_scale(block, linears2scale: list, x, kwargs=None):
        # Grid-search the scale that minimizes the output error of ``block``
        # after pseudo-quantizing ``linears2scale``.
        # w: co, ci
        # x: n, ci
        if kwargs is None:
            kwargs = {}

        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        w_max = get_weight_scale(
            weight, q_group_size=quant_config.get("q_group_size", -1))
        # Clear GPU memory
        del weight
        gc.collect()
        torch.cuda.empty_cache()

        x = x.to(next(block.parameters()).device)
        # Gradients are already disabled by the decorator on the outer function.
        org_out = block(x, **kwargs)
        if isinstance(org_out, tuple):
            org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        # CPU snapshot so every candidate starts from the original weights.
        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
        for grid_idx in range(n_grid):
            ratio = grid_idx / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)
                      ).clamp(min=1e-4).view(-1)
            # Normalize so that max(scales) * min(scales) == 1.
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                # Move once and use the same tensor for scaling and unscaling.
                fc_scales = scales.view(1, -1).to(fc.weight.device)
                fc.weight.mul_(fc_scales)
                fc.weight.data = w_quantize_func(fc.weight.data) / fc_scales
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)

        if best_ratio == -1:
            # Nothing beat +inf, i.e. every candidate loss was NaN or inf.
            raise RuntimeError(
                f"scale search found no finite loss; loss history: {history}")
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs=None):
        # module2inspect: if given, we will check the output diff of this module instead of layers
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        layer_names = tuple(get_op_name(module, m) for m in layers)
        return (get_op_name(module, prev_op), layer_names, scales)

    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list

def apply_scale(module, scales_list, input_feat_dict=None):
    """Fold searched scales into ``module`` in place.

    For every ``(prev_op_name, layer_names, scales)`` entry, the scales are
    absorbed by ``prev_op`` and the layers that follow it:

    * ``nn.Linear`` prev op: ``scale_fc_fc`` (exactly one following layer).
    * ``nn.LayerNorm`` / ``LlamaRMSNorm``: ``scale_ln_fcs``.
    * ``nn.GELU`` / ``BloomGelu``: the activation is replaced in ``module`` by
      a ``ScaledActivation`` wrapper, and ``scale_gelu_fc`` rescales the
      first following layer.

    Each affected op is moved to the GPU (or stays on the CPU when no GPU is
    available) for the update, then moved back to the CPU.

    Args:
        module: Block whose sub-modules are named in ``scales_list``.
        scales_list: Entries of ``(prev_op_name, layer_names, scales)``, as
            returned by ``auto_scale_block``.
        input_feat_dict: Optional mapping from layer name to cached input
            features. When given, each listed layer's input is divided by
            ``scales`` in place to prepare it for clipping.

    Raises:
        NotImplementedError: if ``prev_op`` is of an unsupported type.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        # nn.Module.to moves parameters in place. ``scales`` itself is left
        # where it is: Tensor.cuda()/.cpu() would only return a discarded
        # copy, and the scale_* helpers below move ``scales`` to each weight's
        # device themselves.
        prev_op.to(device)
        for layer in layers:
            layer.to(device)

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu)):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()