"""
In this mixin, I use a different implementation than sat/model/finetune/lora.py
I simply use a fake linear layer (LoraLinear) to replace the target linear layers of any model that uses this LoRA mixin.
"""

import torch
import torch.nn as nn
from sat.model.base_model import BaseMixin
import math
from sat.helpers import print_all
from sat.model.transformer import RowParallelLinear, ColumnParallelLinear

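# The Hack* classes below override _load_from_state_dict so that missing keys are silently
# tolerated: a checkpoint may provide only `weight`/`bias` (a plain linear checkpoint) or
# only the LoRA matrices, and whatever is present is copied in place without errors.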
class HackLinear(nn.Linear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

class HackRowParallelLinear(RowParallelLinear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

class HackColumnParallelLinear(ColumnParallelLinear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

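# Optional 4-bit (NF4) quantization support via bitsandbytes. A quantized weight carries a
# `quant_state` that is not part of the regular tensor state, so HackLinearNF4 saves and
# restores it explicitly next to the uint8 weight.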
try:
    from bitsandbytes.nn import LinearNF4
    def copy_nested_list(src, dst):
        for i in range(len(dst)):
            if type(dst[i]) is torch.Tensor:
                dst[i].copy_(src[i])
            elif type(dst[i]) is list:
                copy_nested_list(src[i], dst[i])
            else:
                dst[i] = src[i]
    class HackLinearNF4(LinearNF4):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if prefix + 'weight' in state_dict:
                self.weight.data.copy_(state_dict[prefix+'weight'])
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
            if prefix + 'bias' in state_dict:
                self.bias.data.copy_(state_dict[prefix+'bias'])
        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix+'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')


class HackParameterList(nn.ParameterList):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for i in range(len(self)):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix+str(i)])

map_cls = {
    nn.Linear: (HackLinear, {}),
    ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}),
    RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True})
}

class LoraLinear(nn.Module):
    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        You can use safely with this layer, ONLY WHEN query_key_value output is query_key_value order.
        If you use a different order like ChatGLM
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            try:
                self.original = HackLinearNF4(in_dim, out_dim)
            except Exception:
                raise Exception('Failed to build the 4-bit layer. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still hit an error after installation, run `from bitsandbytes.nn import LinearNF4` in Python and fix the reported error.')
        else:
            base_cls, kwargs = map_cls[original_cls]
            self.original = base_cls(in_dim, out_dim, **kwargs)
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
        for i in range(partition):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        self.partition = partition
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect implementation: it does not handle errors or report unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            
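    # forward: original(x) + scaling * concat_i( dropout(x) @ A_i^T @ B_i^T ), where A_i is
    # (r, in_dim) and B_i is (out_dim // partition, r). With head_first=True the per-partition
    # LoRA outputs are regrouped per attention head before being added, to match a head-first
    # fused qkv layout.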
    def forward(self, x):
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        for i in range(self.partition):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(self.partition):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)

        return mixed_raw_layer


def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
    # Linear layers without bias are not supported for now.
    out_dim, in_dim = lin.weight.shape
    original_cls = type(lin)
    del lin
    return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args)

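# Merging folds the low-rank update back into a dense weight:
#     W_merged = W_original + scaling * (B_i @ A_i), concatenated over the partitions,
# computed below as (A_i^T @ B_i^T)^T. An NF4 weight is dequantized first with
# bitsandbytes.functional.dequantize_fp4 and the merged result is stored in a fresh layer.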
def merge_linear_lora(lin):
    if lin.original.weight.data.dtype is not torch.uint8:
        weight = lin.original.weight
        out_dim, in_dim = weight.shape
        new_lin = nn.Linear(in_dim, out_dim)
    else:
        import bitsandbytes.functional as F
        weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to(lin.original.bias.data.dtype)
        out_dim, in_dim = weight.shape
        new_lin = HackLinearNF4(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for i in range(lin.partition):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], lin.partition*ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
    new_lin.weight.data = weight + new_qkv.T.to(lin.original.bias.data.dtype)
    return new_lin.cuda() if torch.cuda.is_available() else new_lin

class LoraMixin(BaseMixin):
    def __init__(self, 
                layer_num,
                r: int = 0,
                lora_alpha: int = 1,
                lora_dropout: float = 0.,
                layer_range = None,
                head_first = False,
                num_attention_heads = None,
                hidden_size_per_attention_head = None,
                qlora = False,
                cross_attention = True):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora
        self.cross_attention = cross_attention

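    # reinit() swaps the attention projections (and, on decoder layers, the cross-attention
    # projections) for LoraLinear. `partition` is the number of projections fused into the
    # replaced linear: dense and query use 1, key_value uses 2, query_key_value uses 3.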
    def reinit(self, parent_model):
        for i in self.layer_range:
            print(f'replacing layer {i} attention with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
            if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
                print(f'replacing layer {i} cross attention with lora')
                parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
        if self.qlora:
            print('replacing linear layers with 4-bit (NF4) layers')
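            # Recursively replace every nn.Linear / ColumnParallelLinear / RowParallelLinear
            # in the transformer with HackLinearNF4. The `cache` dict ensures that a module
            # shared by several parents is swapped for one and the same new instance.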
            def replace_linear_with_nf4(model, name=None, cache={}):
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        for i in self.layer_range:
            print(f'merge layer {i} lora attention back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
            if self.transformer.layers[i].is_decoder:
                print(f'merge layer {i} lora cross attention back to linear')
                self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
                self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
                self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(self.transformer.layers[i].cross_attention.key_value)

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)
        
        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(nn.Linear, 1, 100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()