modules.py 17.4 KB
Newer Older
1
2
3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
Tim Dettmers's avatar
Tim Dettmers committed
4
# LICENSE file in the root directory of this source tree.
Tom Aarsen's avatar
Tom Aarsen committed
5
from typing import Optional, TypeVar, Union, overload
Tim Dettmers's avatar
Tim Dettmers committed
6

7
import torch
Tim Dettmers's avatar
Tim Dettmers committed
8
import torch.nn.functional as F
9
from torch import Tensor, device, dtype, nn
Tim Dettmers's avatar
Tim Dettmers committed
10

11
import bitsandbytes as bnb
12
13
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
Tim Dettmers's avatar
Tim Dettmers committed
14
from bitsandbytes.optim import GlobalOptimManager
15
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
Tim Dettmers's avatar
Tim Dettmers committed
16

17
18
T = TypeVar("T", bound="torch.nn.Module")

Tim Dettmers's avatar
Tim Dettmers committed
19

Tim Dettmers's avatar
Tim Dettmers committed
20
class StableEmbedding(torch.nn.Embedding):
    """Embedding layer for use with 8-bit optimizers.

    Differences from ``torch.nn.Embedding``:
    - the output is layer-normalized in the default (full-precision) dtype for
      training stability,
    - the weight is registered with ``GlobalOptimManager`` so 8-bit optimizers
      keep 32-bit optimizer state for it,
    - ``reset_parameters`` uses Xavier-uniform initialization.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        # Fix: pass device/dtype by keyword. Newer PyTorch inserted a
        # positional `_freeze` parameter before `device` in
        # nn.Embedding.__init__, so passing these positionally would silently
        # map `device` onto `_freeze`.
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
            dtype=dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        # Force 32-bit optimizer state for this weight under 8-bit optimizers.
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        """Xavier-uniform init, then zero the padding row (if any)."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """Redefinition of torch.nn.Embedding._fill_padding_idx_with_zero to
        keep this layer compatible with PyTorch < 1.9. If upstream changes,
        this must change too — cumbersome, but it preserves compatibility with
        previous PyTorch releases.
        """
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
82
83
84


class Embedding(torch.nn.Embedding):
    """Drop-in ``torch.nn.Embedding`` whose weight is registered with
    ``GlobalOptimManager`` so 8-bit optimizers keep 32-bit optimizer state
    for it. Uses Xavier-uniform initialization."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device: Optional[device] = None,
    ) -> None:
        super().__init__(
            num_embeddings, embedding_dim, padding_idx, max_norm,
            norm_type, scale_grad_by_freq, sparse, _weight, device=device,
        )
        # Keep full-precision optimizer state for the embedding weight.
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        """Xavier-uniform init, then zero the padding row (if any)."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """Redefinition of torch.nn.Embedding._fill_padding_idx_with_zero for
        compatibility with PyTorch < 1.9; must be kept in sync with upstream.
        """
        idx = self.padding_idx
        if idx is None:
            return
        with torch.no_grad():
            self.weight[idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Standard embedding lookup."""
        return F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse,
        )
Tim Dettmers's avatar
Tim Dettmers committed
140

141
142
class Params4bit(torch.nn.Parameter):
    """Parameter subclass that holds 4-bit quantized weights.

    The tensor data stays unquantized on CPU; the actual 4-bit quantization
    happens on the first move to a CUDA device (see ``cuda``), which also
    populates ``quant_state``.
    """

    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
        if data is None:
            data = torch.empty(0)

        # Attach quantization metadata to the tensor subclass instance.
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.blocksize = blocksize
        self.compress_statistics = compress_statistics
        self.quant_type = quant_type
        self.quant_state = quant_state
        self.data = data
        return self

    def cuda(self, device):
        """Quantize the (fp16-cast) weight to 4 bits while moving to `device`.

        Replaces ``self.data`` with the packed 4-bit tensor and stores the
        resulting quantization state on the parameter.
        """
        w = self.data.contiguous().half().cuda(device)
        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
        self.data = w_4bit
        self.quant_state = quant_state

        return self

    @overload
    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        """Device/dtype move; a CPU->CUDA move triggers quantization.

        Otherwise the existing quantization state is moved (in place!) to the
        target device and a new Params4bit wrapping the moved data is returned.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
            return self.cuda(device)
        else:
            s = self.quant_state
            if s is not None:
                # make sure the quantization state is on the right device
                s[0] = s[0].to(device)
                if self.compress_statistics:
                    # TODO: refactor this. This is a nightmare
                    # for 4-bit:
                    # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
                    # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
                    #s[-2][0] = s[-2][0].to(device) # offset
                    #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax

                    # for 8-bit
                    # NOTE(review): the indices below assume the nested 8-bit
                    # state layout, not the 4-bit layout sketched above —
                    # confirm against bnb.functional.quantize_4bit's state.
                    s[-2][0] = s[-2][0].to(device) # offset
                    s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
                    s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
            # Note: quant_state is shared (and was mutated in place above),
            # not copied, into the new parameter.
            new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                  requires_grad=self.requires_grad, quant_state=self.quant_state,
                                   blocksize=self.blocksize, compress_statistics=self.compress_statistics,
                                   quant_type=self.quant_type)

            return new_param

203
class Linear4bit(nn.Linear):
    """Linear layer whose weight is stored as a 4-bit quantized Params4bit.

    The weight is quantized for real when the module is moved to CUDA; the
    matmul runs via ``bnb.matmul_4bit`` and the output is cast back to the
    input dtype.
    """

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
        super().__init__(input_features, output_features, bias, device)
        # Swap the dense weight for its 4-bit wrapper (frozen: no gradients).
        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor):
        # The weight converts itself on device moves, but the bias has to be
        # brought to the input dtype by hand.
        bias_param = self.bias
        if bias_param is not None and bias_param.dtype != x.dtype:
            bias_param.data = bias_param.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        mm_bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=mm_bias, quant_state=self.weight.quant_state)
        return out.to(inp_dtype)

227
class LinearFP4(Linear4bit):
    """Linear4bit specialized to the FP4 quantization data type."""

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
230
231

class LinearNF4(Linear4bit):
    """Linear4bit specialized to the NF4 (normal-float 4-bit) data type."""

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
234

235

Tim Dettmers's avatar
Tim Dettmers committed
236

Tim Dettmers's avatar
Tim Dettmers committed
237
class Int8Params(torch.nn.Parameter):
    """Parameter wrapper for LLM.int8() weights.

    Carries the int8-quantized weight (``CB``) and its row scales (``SCB``)
    alongside the tensor data. Quantization happens on the first move to a
    CUDA device unless ``has_fp16_weights`` is set.
    """

    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        if data is None:
            data = torch.empty(0)
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Fix: store the quantization state on the *instance*. The original
        # assigned cls.has_fp16_weights / cls.CB / cls.SCB — class attributes
        # shared by every Int8Params — so constructing one parameter silently
        # overwrote the flags of all previously created ones. The CB/SCB
        # constructor arguments are now honored as well (default None keeps
        # the old observable behavior).
        obj.has_fp16_weights = has_fp16_weights
        obj.CB = CB
        obj.SCB = SCB
        return obj

    def cuda(self, device):
        """Move to `device`; quantizes to int8 unless fp16 weights are kept."""
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight;
            # it is converted to the turing/ampere layout on the first
            # inference pass (see Linear8bitLt.forward)
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            setattr(self, "CB", CB)
            setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self: "T",
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "T":
        ...

    @overload
    def to(self: "T", dtype: Union[dtype, str], non_blocking: bool = ...) -> "T":
        ...

    @overload
    def to(self: "T", tensor: Tensor, non_blocking: bool = ...) -> "T":
        ...

    def to(self, *args, **kwargs):
        """Device/dtype move; a CPU->CUDA move triggers quantization."""
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


Tim Dettmers's avatar
Tim Dettmers committed
311

Tim Dettmers's avatar
Tim Dettmers committed
312
class Linear8bitLt(nn.Linear):
    """Linear layer running the bitsandbytes LLM.int8() mixed-precision matmul.

    The weight is wrapped in ``Int8Params``; quantization state lives in a
    ``bnb.MatmulLtState`` that is handed the quantized tensors on first use.
    """

    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                       memory_efficient_backward=False, threshold=0.0, index=None, device=None):
        super().__init__(input_features, output_features, bias, device)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        # Outlier pooling only applies to quantized inference with a threshold.
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # If the weight was converted to the GPU tile layout (CxB) and the
        # row-major copy (CB) was discarded, temporarily restore the row-major
        # layout so the serialized weight is portable.
        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
            reorder_layout = True
            weight_clone = self.weight.data.clone()
        else:
            reorder_layout = False

        try:
            if reorder_layout:
                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)

            super()._save_to_state_dict(destination, prefix, keep_vars)

            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
            weight_name = "SCB"

            # case 1: .cuda was called, SCB is in self.weight
            param_from_weight = getattr(self.weight, weight_name)
            # case 2: self.init_8bit_state was called, SCB is in self.state
            param_from_state = getattr(self.state, weight_name)

            key_name = prefix + f"{weight_name}"
            if param_from_weight is not None:
                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
            elif not self.state.has_fp16_weights and param_from_state is not None:
                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
        finally:
            # Always restore the tile-layout weight, even if saving failed.
            if reorder_layout:
                self.weight.data = weight_clone

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                      error_msgs)
        # Fix: iterate over a snapshot. The loop removes entries from
        # `unexpected_keys`, and mutating a list while iterating it skips the
        # element after each removal.
        for key in list(unexpected_keys):
            input_name = key[len(prefix):]
            if input_name == "SCB":
                if self.weight.SCB is None:
                    # buffers not yet initialized, can't load into them directly
                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                       "not supported. Please call module.cuda() before module.load_state_dict()")

                input_param = state_dict[key]
                self.weight.SCB.copy_(input_param)
                unexpected_keys.remove(key)

    def init_8bit_state(self):
        """Hand CB/SCB over from the parameter to the matmul state."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
Tim Dettmers's avatar
Tim Dettmers committed
399

Mitchell Wortsman's avatar
Mitchell Wortsman committed
400

401
class OutlierAwareLinear(nn.Linear):
    """Linear base class that quantizes its weight on first use while treating
    outlier dimensions (found via OutlierTracer) specially.

    Subclasses must implement ``quantize_weight`` and ``forward_with_outliers``.
    """

    def __init__(self, input_features, output_features, bias=True, device=None):
        super().__init__(input_features, output_features, bias, device)
        self.outlier_dim = None   # lazily resolved outlier indices
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')

    def quantize_weight(self, w, outlier_idx):
        raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
            outlier_idx = tracer.get_outliers(self.weight)
            self.outlier_dim = outlier_idx

        if not self.is_quantized:
            w = self.quantize_weight(self.weight, self.outlier_dim)
            self.weight.data.copy_(w)
            self.is_quantized = True

        # Fix: the original fell off the end here and returned None, discarding
        # the layer's output entirely; delegate to the subclass implementation.
        return self.forward_with_outliers(x, self.outlier_dim)

427
class SwitchBackLinearBnb(nn.Linear):
Mitchell Wortsman's avatar
Mitchell Wortsman committed
428
429
430
431
432
433
434
435
436
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
        device=None
    ):
        """Linear layer with int8 matmul state (SwitchBack variant).

        Mirrors Linear8bitLt.__init__: builds a standard nn.Linear, prepares
        the bitsandbytes matmul state, and wraps the weight in Int8Params.
        """
        super().__init__(
            input_features, output_features, bias, device
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        # Outlier pooling only applies to quantized (non-fp16) weights with a
        # positive outlier threshold.
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias