scale.py 5.06 KB
Newer Older
Casper's avatar
Casper committed
1
2
import torch
import torch.nn as nn
Vik Paruchuri's avatar
Vik Paruchuri committed
3
from typing import Tuple, List
4
from awq.utils.utils import get_best_device
Casper's avatar
Casper committed
5
from awq.modules.act import ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
6
from awq.utils.module import get_op_by_name, set_op_by_name
Casper's avatar
Casper committed
7
8
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
TechxGenus's avatar
TechxGenus committed
9
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
twaka's avatar
twaka committed
10
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
Casper's avatar
Casper committed
11

TechxGenus's avatar
TechxGenus committed
12
# Normalization layer types that apply_scale folds scales into via
# scale_ln_fcs (apply_scale additionally accepts any class whose name
# contains "rmsnorm").
allowed_norms = [
    nn.LayerNorm,
    LlamaRMSNorm,
    GemmaRMSNorm,
]

# Activation types that apply_scale wraps in ScaledActivation before
# multiplying the following linear layer's weight by the scales.
allowed_act_fns = [
    nn.GELU,
    BloomGelu,
    NewGELUActivation,
    PytorchGELUTanh,
    GELUActivation,
]

Casper's avatar
Casper committed
21
22
23
24
25

@torch.no_grad()
def apply_clip(module, clip_list: List[Tuple[str, torch.Tensor]]):
    """Clamp the weights of named linear sub-layers of ``module`` in place.

    Args:
        module: Parent module; each name in ``clip_list`` is resolved with
            ``get_op_by_name``.
        clip_list: Sequence of ``(layer_name, max_val)`` pairs. The layer's
            weight is reshaped to ``(*max_val.shape[:2], -1)``, clamped to
            ``[-max_val, max_val]``, and reshaped back to its original shape.

    Each layer is moved to ``get_best_device()`` for the clamp and moved to
    CPU afterwards.
    """
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
        layer.to(get_best_device())
        max_val = max_val.to(layer.weight.device)

        org_shape = layer.weight.shape
        # View the weight in the same leading layout as max_val.
        # NOTE(review): assumes max_val carries a trailing singleton dim so it
        # broadcasts across the last (per-group) axis — confirm against caller.
        grouped = layer.weight.data.reshape(*max_val.shape[:2], -1)
        clipped = torch.clamp(grouped, -max_val, max_val)
        layer.weight.data = clipped.reshape(org_shape)

        layer.cpu()


def apply_scale(module, scales_list, input_feat_dict=None):
    """Fold per-channel scales into sub-layers of ``module`` in place.

    For each ``(prev_op_name, layer_names, scales)`` entry, the named ops are
    resolved with ``get_op_by_name``, moved to ``get_best_device()``, and
    dispatched on the type of ``prev_op``:

    * ``nn.Linear`` followed by an ``nn.Linear`` -> ``scale_fc_fcs``
    * any other ``nn.Linear`` -> ``scale_fc_fc`` (exactly one layer expected)
    * a type in ``allowed_norms`` or any class whose name contains
      "rmsnorm" -> ``scale_ln_fcs``
    * a type in ``allowed_act_fns`` -> replaced in ``module`` by a
      ``ScaledActivation`` wrapper, then ``scale_gelu_fc`` on ``layers[0]``

    Args:
        module: Parent module containing the named ops.
        scales_list: Iterable of ``(prev_op_name, layer_names, scales)``.
        input_feat_dict: Optional mapping from layer name to a cached input
            tensor. Matching entries are divided in place by ``scales``.

    Raises:
        NotImplementedError: If ``prev_op``'s type is not supported.
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        best_device = get_best_device()
        prev_op.to(best_device)
        for layer in layers:
            layer.to(best_device)
        # ``scales`` is intentionally not relocated here: Tensor.to()/.cpu()
        # return new tensors rather than moving in place, and each scale_*
        # helper below (and the input-feature division) moves scales to the
        # device it needs.

        if isinstance(prev_op, nn.Linear) and isinstance(layers[0], nn.Linear):
            scale_fc_fcs(prev_op, layers, scales)

        elif isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)

        elif (
            any(isinstance(prev_op, t) for t in allowed_norms)
            or "rmsnorm" in str(prev_op.__class__).lower()
        ):
            scale_ln_fcs(prev_op, layers, scales)

        elif any(isinstance(prev_op, t) for t in allowed_act_fns):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)

        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                # Skip the modules that are not quantized
                if layer_name in input_feat_dict:
                    inp = input_feat_dict[layer_name]
                    inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()

Casper's avatar
Casper committed
84

Casper's avatar
Casper committed
85
@torch.no_grad()
def scale_ln_fcs(ln: nn.Module, fcs: List[nn.Linear], scales: torch.Tensor):
    """Fold ``scales`` into a normalization layer and the linears it feeds.

    The norm's weight (and bias, if it has one) is divided by ``scales``, and
    each linear layer's weight is multiplied column-wise by ``scales``. All
    updates are in place.

    Args:
        ln: Normalization module with a ``weight`` parameter.
        fcs: Linear layer(s) consuming ``ln``'s output; a single layer is
            also accepted.
        scales: Per-channel scales, moved to ``ln.weight``'s device.
            NOTE(review): presumably 1-D with length ``ln.weight.numel()`` —
            confirm.

    Raises:
        AssertionError: If any parameter contains NaN after scaling.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    # GemmaRMSNorm is different from Llama's in that it multiplies
    # (1 + weight) to the output, instead of just weight.
    if isinstance(ln, GemmaRMSNorm):
        ln.weight += 1
        ln.weight.div_(scales)
        ln.weight -= 1
    else:
        ln.weight.div_(scales)

    if hasattr(ln, "bias") and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

Casper's avatar
Casper committed
113

Casper's avatar
Casper committed
114
@torch.no_grad()
def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
    """Fold ``scales`` between two consecutive linear layers, in place.

    The last ``scales.size(0)`` output rows of ``fc1`` (weight and bias) are
    divided by ``scales``, and ``fc2``'s weight is multiplied column-wise by
    ``scales``. NOTE(review): scaling only the trailing rows presumably
    supports an ``fc1`` whose trailing output slice is what feeds ``fc2``
    (e.g. a fused projection) — confirm.

    Args:
        fc1: Linear layer whose output is divided by ``scales``.
        fc2: Linear layer whose input columns are multiplied by ``scales``.
        scales: Per-channel scales, moved to ``fc1.weight``'s device.

    Raises:
        AssertionError: If either layer is not ``nn.Linear``, or any parameter
            contains NaN after scaling.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)
    num_scaled = scales.size(0)

    fc1.weight[-num_scaled:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # Slice the bias exactly like the weight rows; dividing the whole bias
        # fails to broadcast whenever num_scaled < fc1.out_features.
        fc1.bias[-num_scaled:].div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0

Casper's avatar
Casper committed
132

133
134
135
136
@torch.no_grad()
def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
    """Fold ``scales`` between a linear layer and the linears it feeds.

    The last ``scales.size(0)`` output rows of ``fc1`` (weight and bias) are
    divided by ``scales``, and every layer in ``fcs`` has its weight
    multiplied column-wise by ``scales``. All updates are in place.

    Args:
        fc1: Linear layer whose output is divided by ``scales``.
        fcs: Linear layer(s) consuming ``fc1``'s output; a single layer is
            also accepted.
        scales: Per-channel scales, moved to ``fc1.weight``'s device.

    Raises:
        AssertionError: If any parameter contains NaN after scaling.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(fc1.weight.device)
    num_scaled = scales.size(0)

    fc1.weight[-num_scaled:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # Slice the bias exactly like the weight rows; dividing the whole bias
        # fails to broadcast whenever num_scaled < fc1.out_features.
        fc1.bias[-num_scaled:].div_(scales.view(-1))

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
Casper's avatar
Casper committed
152

Casper's avatar
Casper committed
153

Casper's avatar
Casper committed
154
@torch.no_grad()
def scale_gelu_fc(gelu: nn.Module, fc: nn.Linear, scales: torch.Tensor):
    """Multiply the input columns of ``fc``'s weight by ``scales``, in place.

    The activation itself is not modified here. In this file, ``apply_scale``
    wraps it in ``ScaledActivation`` before calling this function.

    Args:
        gelu: Activation preceding ``fc``; must be an instance of one of the
            types in ``allowed_act_fns``.
        fc: Linear layer consuming the activation's output.
        scales: Per-channel scales, moved to ``fc.weight``'s device.

    Raises:
        AssertionError: If ``gelu``/``fc`` have unexpected types, or any
            parameter of ``fc`` contains NaN after scaling.
    """
    assert any(isinstance(gelu, t) for t in allowed_act_fns)
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0