lora.py
import math
from typing import Optional

import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLinear(lora.LoRALayer, nn.Module):
    """Replaces in-place ops with out-of-place ops to work with Gemini. Converts a torch.nn.Linear to LoraLinear.
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,    # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
    ):
        nn.Module.__init__(self)
        lora.LoRALayer.__init__(self,
                                r=r,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                merge_weights=merge_weights)
        self.weight = weight
        self.bias = bias

        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
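            # Low-rank decomposition: the weight update is lora_B @ lora_A with
            # lora_A: (r, in_features) and lora_B: (out_features, r), scaled by
            # lora_alpha / r, so the effective weight is W + (B @ A) * scaling.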
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        if hasattr(self, 'lora_A'):
            # Initialize A with the default values for nn.Linear and set B to zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
                    # FIXME(csric): temporary fix
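                    # eval() deletes lora_A/lora_B after merging them into the weight.
                    # The re-created factors start from scratch (lora_B is zero), so the
                    # previously merged update stays folded into the frozen weight.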
                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
                    self.reset_parameters()
                else:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

    def eval(self):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
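                # Drop the LoRA factors once they are folded into the weight;
                # train() will re-create and re-initialize them if training resumes.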
                delattr(self, 'lora_A')
                delattr(self, 'lora_B')
            self.merged = True

    def forward(self, x: torch.Tensor):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
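

# A minimal, hypothetical usage sketch (not part of the original file): wrap the
# parameters of an existing nn.Linear in LoraLinear and check the output shape.
# All layer sizes below are arbitrary.
def _example_lora_linear() -> None:
    base = nn.Linear(16, 8)
    lora_layer = LoraLinear(base.weight, base.bias, r=4, merge_weights=False)
    x = torch.randn(2, 16)
    out = lora_layer(x)
    assert out.shape == (2, 8)
    # lora_B starts at zero, so the output initially matches the frozen base layer.
    assert torch.allclose(out, base(x))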


def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
    return lora_linear


def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, _lora_linear_wrapper(child, lora_rank))
        else:
            _convert_to_lora_recursively(child, lora_rank)


def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module.

    Args:
        module (nn.Module): The module to convert.
        lora_rank (int): LoRA rank.
        lora_train_bias (str, optional): Whether to train biases.
            'none' trains no biases, 'all' trains all biases, 'lora_only' trains only the biases of LoRA layers.
            Defaults to 'none'.

    Returns:
        nn.Module: The converted module.
    """
    if lora_rank <= 0:
        return module
    _convert_to_lora_recursively(module, lora_rank)
    lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module
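

# A hypothetical sketch of converting a model (the MLP below is illustrative):
# every nn.Linear is replaced by LoraLinear, and only LoRA factors stay trainable.
def _example_convert_to_lora_module() -> None:
    mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
    mlp = convert_to_lora_module(mlp, lora_rank=4)
    trainable = [name for name, p in mlp.named_parameters() if p.requires_grad]
    # With lora_train_bias='none', only lora_A / lora_B require gradients.
    assert all('lora_' in name for name in trainable)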


class LoRAModule(nn.Module):
    """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
134
    This class will convert all torch.nn.Linear layer to LoraLinear layer.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Whether to train biases.
            'none' trains no biases, 'all' trains all biases, 'lora_only' trains only the biases of LoRA layers.
            Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
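

# A hypothetical derived class, sketched to show the documented pattern: build the
# submodules in `__init__()` and call `convert_to_lora()` at the bottom so every
# nn.Linear inside is swapped for LoraLinear. The architecture is illustrative only.
class _ExampleLoRAMLP(LoRAModule):

    def __init__(self, lora_rank: int = 4, lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
        self.convert_to_lora()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)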