# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
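    """Embedding layer with improved training stability.

    Compared to ``torch.nn.Embedding``, this layer uses Xavier-uniform
    initialization, applies a ``LayerNorm`` to the embedding output, and
    registers its weight with ``GlobalOptimManager`` so that bitsandbytes
    optimizers keep 32-bit state for it even when the rest of the model uses
    8-bit optimizer state.

    Minimal usage sketch (vocabulary size and dimension are placeholders,
    assuming ``import bitsandbytes as bnb``)::

        emb = bnb.nn.StableEmbedding(50257, 1024)
        out = emb(torch.randint(0, 50257, (8, 128)))
    """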
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(StableEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
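        # register an override so that bitsandbytes optimizers keep 32-bit
        # state for this embedding weight instead of 8-bit state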
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
Tim Dettmers's avatar
Tim Dettmers committed
62
63
64
65
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
66
67
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)


class Embedding(torch.nn.Embedding):
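    """Drop-in replacement for ``torch.nn.Embedding``.

    Behaves like the standard embedding layer but registers its weight with
    ``GlobalOptimManager`` so that bitsandbytes optimizers keep 32-bit state
    for it. Unlike ``StableEmbedding``, it applies no LayerNorm to the output.
    """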
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
118
119
120
121
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
122
123
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb


class Int8Params(torch.nn.Parameter):
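    """Parameter subclass that holds int8-quantized weights.

    When ``has_fp16_weights`` is False, moving the parameter to a CUDA device
    quantizes the fp16 weight with ``bnb.functional.double_quant`` and keeps
    the row-major int8 weight in ``CB`` along with its row-wise scaling
    constants in ``SCB``.
    """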
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the Turing/Ampere format during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
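            # keep only the row-major int8 weight (CB) and its row-wise
            # quantization statistics (SCB); the column-wise variants are not needed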
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

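        # a CPU parameter moved to a CUDA device is routed through self.cuda(),
        # which performs the int8 quantization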
        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


class Linear8bitLt(nn.Linear):
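    """Linear layer using 8-bit matrix multiplication (LLM.int8()).

    The weight is stored as ``Int8Params`` and the forward pass dispatches to
    ``bnb.matmul`` with a ``MatmulLtState``. A ``threshold`` greater than zero
    enables the mixed-precision decomposition in which activation outliers
    above the threshold are handled in fp16.

    Minimal usage sketch (sizes and threshold are illustrative, ``x`` is a
    placeholder input, assuming ``import bitsandbytes as bnb``)::

        linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()
        out = linear(x.half())
    """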
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
    ):
        super(Linear8bitLt, self).__init__(
            input_features, output_features, bias
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

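        # on the first forward pass, move the quantized weight buffers (CB/SCB)
        # from the parameter into the matmul state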
        if self.weight.CB is not None:
            self.init_8bit_state()
        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()
        # assert not self.state.has_fp16_weights
        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights and self.state.CB is not None:
            # we converted the 8-bit row-major weight to the Turing/Ampere format in the first inference pass
            # we no longer need the row-major weight
            del self.state.CB
            self.weight.data = self.state.CxB

        return out


class Linear8bit(nn.Linear):
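    """Linear layer using 8-bit quantized matrix multiplication.

    The quantization scheme is selected via ``quant_type``. When an ``args``
    object is provided, the forward pass periodically clips the weights and
    uses the sparse-decomposed 8-bit kernel; otherwise it falls back to
    ``bnb.nn.functional.linear8bit``.
    """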
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        quant_type="vector",
        index=None,
        args=None,
        sparse_decomp=False,
    ):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index
        self.args = args
        self.iter = 0

    def forward(self, x):
        self.iter += 1
        if self.args is not None and self.iter % self.args.clip_freq == 0:
            with torch.no_grad():
                maxval, maxidx = torch.topk(
                    torch.abs(self.weight.flatten()), k=self.args.clip_idx
                )
                if not dist.is_initialized() or dist.get_rank() == 0:
                    print("clip", maxval[-1].item())
                self.weight.clip_(-maxval[-1], maxval[-1])

        if self.args is not None:
            out = bnb.nn.functional.sparse_decomposed_linear8bit(
                x,
                self.weight,
                self.bias,
                qval=self.args.sparse_decomp_val,
                quant_type=self.args.quant_type,
            )
        else:
            out = bnb.nn.functional.linear8bit(
                x, self.weight, self.bias, quant_type=self.quant_type
            )

        return out