# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
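    """An embedding layer geared toward stable training: the weight is
    initialized with Xavier uniform, the embedding output is passed through a
    LayerNorm computed in full precision, and the `weight` parameter is
    registered with GlobalOptimManager so its optimizer state is kept in 32
    bits even when a bitsandbytes 8-bit optimizer is used.
    """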
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device,
            dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
Tim Dettmers's avatar
Tim Dettmers committed
53
54
55
56
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
57
58
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
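
# A minimal usage sketch for StableEmbedding (the vocabulary size, embedding
# width and input shape below are arbitrary illustration, not values required
# by the module):
#
#     emb = StableEmbedding(num_embeddings=1024, embedding_dim=128)
#     tokens = torch.randint(0, 1024, (4, 16))   # (batch, sequence)
#     out = emb(tokens)                          # (4, 16, 128), cast back to emb.weight.dtype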


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
112
113
114
115
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
116
117
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
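
# Embedding is a plain torch.nn.Embedding whose `weight` is registered with
# GlobalOptimManager so that its optimizer state is kept in 32 bits even when a
# bitsandbytes 8-bit optimizer is used; usage is identical to torch.nn.Embedding.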


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        # defaults; CB and SCB are populated with the quantized weight and its
        # row-wise scaling statistics once the parameter is moved to a CUDA device
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the Turing/Ampere layout during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            setattr(self, "CB", CB)
            setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param
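
# Behavior sketch for Int8Params (the tensor size below is arbitrary
# illustration): with has_fp16_weights=False, moving the parameter to a CUDA
# device quantizes the data to an int8 row-major matrix and keeps the row-wise
# scaling constants on the parameter itself:
#
#     w = Int8Params(torch.randn(64, 64), requires_grad=False, has_fp16_weights=False)
#     w = w.to("cuda")   # triggers double_quant; afterwards w.data/w.CB are int8 and w.SCB holds the scales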


class Linear8bitLt(nn.Linear):
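    """A drop-in replacement for `nn.Linear` with an 8-bit weight: the weight is
    held as an `Int8Params` and quantized to int8 when the module is moved to a
    CUDA device, and the forward pass runs through `bnb.matmul` using the
    quantization state in `self.state`. With `threshold > 0` and
    `has_fp16_weights=False`, outliers above the threshold are handled
    separately in higher precision.
    """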
    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None):
        super().__init__(input_features, output_features, bias)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted the 8-bit row-major weight to the Turing/Ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
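
# A minimal usage sketch for Linear8bitLt in inference mode (the layer sizes,
# threshold and batch shape below are arbitrary illustration):
#
#     layer = Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
#     layer = layer.cuda()                 # the weight is quantized to int8 on this move
#     x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
#     y = layer(x)                         # int8 matmul via bnb.matmul, with outliers above the threshold kept in fp16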