import gc
import logging

import torch
import torch.nn as nn

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

__all__ = ["auto_scale_block", "apply_scale"]

norms = [nn.LayerNorm, LlamaRMSNorm]
act_functions = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh]

@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    # Average normalized weight magnitude per input channel: |W| is
    # normalized by its max within each quantization group, then averaged
    # over the output dimension.
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale
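
# Illustrative sketch (hypothetical toy shapes, not part of the API): for a
# Linear weight of shape (out_features, in_features),
#   w = torch.randn(512, 4096)
#   get_weight_scale(w, q_group_size=128)   # -> tensor of shape (4096,)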


@torch.no_grad()
def get_act_scale(x):
    # Mean absolute activation per channel, averaged over all tokens.
    return x.abs().view(-1, x.shape[-1]).mean(0)
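
# Illustrative sketch (hypothetical toy shapes): for activations of shape
# (batch * seq_len, hidden_size),
#   x = torch.randn(1024, 4096)
#   get_act_scale(x)   # -> tensor of shape (4096,)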


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    # Fold the scales into a norm layer and the linear layers it feeds:
    # divide the norm's affine parameters by the scales and multiply the
    # fcs' input channels by them, leaving the block's output unchanged.
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
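
# Equivalence sketch (hypothetical toy shapes): before quantization the
# rescaling is exact, since (ln(x) / s) @ (W * s).T == ln(x) @ W.T:
#   ln, fc = nn.LayerNorm(64), nn.Linear(64, 64)
#   x, s = torch.randn(4, 64), torch.rand(64) + 0.5
#   y0 = fc(ln(x)); scale_ln_fcs(ln, fc, s); y1 = fc(ln(x))
#   torch.allclose(y0, y1, atol=1e-5)   # True up to float error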


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    # Fold the scales into two chained linear layers: divide fc1's output
    # channels by the scales and multiply fc2's input channels by them.
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    # fc1.out_features may exceed fc2.in_features (e.g. a fused projection),
    # so only the trailing output channels of fc1 that feed fc2 are rescaled.
    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0
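
# Illustrative sketch (hypothetical toy layers, full-width scales): for a
# v_proj -> o_proj pair the composition is preserved before quantization:
#   v, o = nn.Linear(64, 64), nn.Linear(64, 64)
#   x, s = torch.randn(4, 64), torch.rand(64) + 0.5
#   y0 = o(v(x)); scale_fc_fc(v, o, s); y1 = o(v(x))   # y0 ≈ y1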


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    # Multiply the fc's input channels by the scales; the matching
    # ScaledActivation wrapper (see apply_scale below) divides the
    # activation's output by the same scales, so the composition is unchanged.
    assert any(isinstance(gelu, t) for t in act_functions)
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
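
# Illustrative sketch (assumes ScaledActivation divides its output by the
# scales, as it is used in apply_scale below):
#   gelu, fc = nn.GELU(), nn.Linear(64, 64)
#   x, s = torch.randn(2, 4, 64), torch.rand(64) + 0.5
#   y0 = fc(gelu(x))
#   scaled = ScaledActivation(gelu, s)   # divides gelu's output by s
#   scale_gelu_fc(gelu, fc, s)           # multiplies fc's input channels by s
#   y1 = fc(scaled(x))                   # y1 ≈ y0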

@torch.no_grad()
def auto_scale_block(awq_model,
                     module,
                     module_kwargs,
                     quant_config,
                     input_feat):
    from .quantizer import pseudo_quantize_tensor

    # First, build the weight pseudo-quantization function.
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                w_bit=quant_config["w_bit"],
                q_group_size=quant_config["q_group_size"],
            ).detach()
    else:
        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(block, linears2scale: list, x, kwargs={}):
        # w: co, ci
        # x: n, ci
        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        w_max = get_weight_scale(
            weight, q_group_size=quant_config.get("q_group_size", -1))
        # Clear GPU memory
        del weight
        gc.collect()
        torch.cuda.empty_cache()

        x = x.to(next(block.parameters()).device)
        with torch.no_grad():
            org_out = block(x, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
        for ratio in range(n_grid):
            ratio = ratio / n_grid
            # AWQ scaling rule: s = x_max^ratio / w_max^(1 - ratio),
            # normalized so that max(s) * min(s) == 1.
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)
                      ).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = w_quantize_func(
                    fc.weight.data) / (scales.view(1, -1))
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)
        if best_ratio == -1:
            logging.debug(history)
            raise RuntimeError("auto-scale grid search found no valid ratio")
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        # module2inspect: if given, measure the output difference of this
        # module instead of the layers themselves.
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

    # Each entry describes one rescaling group and unpacks into
    # _auto_get_scale(prev_op=..., layers=..., inp=..., module2inspect=..., kwargs=...).
    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list
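
# Each scales_list entry is a tuple of
#   (prev_op_name, (layer_name, ...), scales_tensor),
# e.g. (hypothetical LLaMA-style names):
#   ("post_attention_layernorm", ("mlp.gate_proj", "mlp.up_proj"), tensor([...]))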

def apply_scale(module, scales_list, input_feat_dict=None):
    # Apply the precomputed scales to a block, dispatching on the type of
    # the op that precedes each group of linear layers.
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales = scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)

        elif any(isinstance(prev_op, t) for t in norms) \
             or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)

        elif any(isinstance(prev_op, t) for t in act_functions):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # Apply the scaling to the input features if given, to prepare
        # them for the subsequent clipping search.
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales = scales.cpu()
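
# Usage sketch (hypothetical driver code; the real caller feeds calibration
# activations for one transformer block at a time):
#   scales_list = auto_scale_block(
#       awq_model, block, module_kwargs, quant_config, input_feat)
#   apply_scale(block, scales_list, input_feat_dict=input_feat)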