# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(StableEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
Tim Dettmers's avatar
Tim Dettmers committed
62
63
64
65
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
66
67
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)
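
# A minimal usage sketch (illustrative only; the sizes below are made up):
# StableEmbedding works like torch.nn.Embedding but applies a LayerNorm to the
# looked-up embeddings and registers its weight with GlobalOptimManager so that
# bnb optimizers keep 32-bit state for this parameter.
#
#     emb = StableEmbedding(num_embeddings=30000, embedding_dim=512)
#     tokens = torch.randint(0, 30000, (8, 128))
#     out = emb(tokens)  # shape (8, 128, 512), LayerNorm applied to the output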


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
118
119
120
121
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
122
123
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the Turing/Ampere layout during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param
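
# A rough sketch of the flow above (illustrative; assumes a CUDA device is available):
# moving an Int8Params with has_fp16_weights=False to the GPU quantizes the fp16
# weight with bnb.functional.double_quant and keeps the int8 matrix plus its
# scaling constants on the parameter.
#
#     W = torch.randn(1024, 1024, dtype=torch.float16)
#     p = Int8Params(W, requires_grad=False, has_fp16_weights=False).cuda("cuda")
#     p.CB   # int8 row-major quantized weight (also referenced by p.data)
#     p.SCB  # scaling constants returned by double_quant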


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super(Linear8bitLt, self).__init__(
            input_features, output_features, bias
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if not self.state.memory_efficient_backward and self.state.CB is not None:
                # we converted the 8-bit row-major weight to Turing/Ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
            elif self.state.memory_efficient_backward and self.state.CxB is not None:
                # For memory-efficient backward, we convert the 8-bit row-major weight to Turing/Ampere format at each inference pass.
                # Thus, we delete CxB from the state.
                del self.state.CxB

        return out
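
# An illustrative usage sketch (assumes a CUDA device; the sizes and threshold value
# are examples, not recommendations): with has_fp16_weights=False the weight is
# quantized to int8 when the module is moved to the GPU, and activations are
# typically passed in fp16.
#
#     linear = Linear8bitLt(4096, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
#     linear = linear.cuda()  # Int8Params.cuda() quantizes the weight here
#     x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
#     out = linear(x)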