# utils.py
1
2
import bitsandbytes as bnb
import torch
3
4
import torch.nn as nn

5
6
7
8
9
10
11
12
13
14
15

class Linear8bit(nn.Linear):
    """``nn.Linear`` subclass whose weight is quantized with bitsandbytes.

    The weight is quantized once, at construction time, by :meth:`quant`, and
    the forward pass is delegated to ``bnb.matmul`` using a per-layer
    ``bnb.MatmulLtState``.

    NOTE(review): :meth:`quant` moves the weight to CUDA unconditionally, so
    constructing this layer requires a CUDA device.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=False,
        memory_efficient_backward=False,
        threshold=6.0,
        weight_data=None,
        bias_data=None,
    ):
        """Build the layer and immediately quantize its weight.

        Args:
            input_features: Size of each input sample (``nn.Linear`` in_features).
            output_features: Size of each output sample (``nn.Linear`` out_features).
            bias: Forwarded to ``nn.Linear.__init__``. Note that the bias it
                creates is always replaced by ``bias_data`` afterwards.
            has_fp16_weights: Stored on the matmul state.
            memory_efficient_backward: Stored on the matmul state.
            threshold: Stored on the matmul state (presumably the LLM.int8
                outlier threshold -- confirm against the bitsandbytes docs).
            weight_data: ``nn.Parameter`` holding the weight to quantize. When
                ``None``, the weight initialized by ``nn.Linear`` is quantized
                instead (previously this case crashed inside ``quant``).
            bias_data: ``nn.Parameter`` used as the bias, or ``None`` for a
                bias-free layer.
        """
        super().__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        # The bias always comes from ``bias_data``; ``None`` removes the bias
        # that nn.Linear may have created.
        self.bias = bias_data
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        # Placeholder so that "SCB" is a registered parameter before quant()
        # swaps in the real value.
        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
        if weight_data is not None:
            self.weight = weight_data
        self.quant()

    def quant(self):
        """Replace ``weight`` and ``SCB`` with the outputs of ``double_quant``.

        The current weight is made contiguous, cast to float16 and moved to
        CUDA before being passed to ``bnb.functional.double_quant``. The first
        and third return values become the new non-trainable ``weight`` and
        ``SCB`` parameters.
        """
        half_weight = self.weight.data.contiguous().half().cuda()
        CB, _, SCB, _, _ = bnb.functional.double_quant(half_weight)
        # Delete-then-assign (rather than plain assignment) is kept from the
        # original: it moves each entry to the end of the parameter dict, so
        # the resulting parameter order stays the same as before.
        del self.weight
        self.weight = nn.Parameter(CB, requires_grad=False)
        del self.SCB
        self.SCB = nn.Parameter(SCB, requires_grad=False)
        del half_weight

    def forward(self, x):
        """Apply the quantized linear transform to ``x`` via ``bnb.matmul``."""
        self.state.is_training = self.training

        # The bias is cast to float16 in place the first time it is seen in
        # another dtype.
        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()

        self.state.CB = self.weight.data
        self.state.SCB = self.SCB.data

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
        # Drop the reference to whatever matmul cached in ``CxB``. Assigning
        # None instead of ``del`` avoids an AttributeError when the state has
        # no instance-level ``CxB`` (e.g. matmul did not populate it, or it
        # was already removed by a previous call).
        self.state.CxB = None
        return out

53

54
55
56
57
58
def replace_module(model):
    """Recursively replace ``nn.Linear`` children of ``model`` with ``Linear8bit``.

    Every direct or nested child that is an ``nn.Linear`` is swapped in place
    for a ``Linear8bit`` built from the child's ``in_features``,
    ``out_features``, ``weight`` and ``bias`` with ``threshold=6.0``.

    Children are left untouched when:
      * their attribute name contains ``"out_proj"``, or
      * they are already ``Linear8bit`` instances (so calling this function
        twice does not quantize an already-quantized weight again).

    Args:
        model: Module whose children are converted; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    for name, module in model.named_children():
        # Recursing into a leaf is a no-op, so no child-count check is needed.
        replace_module(module)

        if not isinstance(module, nn.Linear) or "out_proj" in name:
            continue
        # Linear8bit subclasses nn.Linear; skip layers that were already
        # converted.
        if isinstance(module, Linear8bit):
            continue

        model._modules[name] = Linear8bit(
            input_features=module.in_features,
            output_features=module.out_features,
            threshold=6.0,
            weight_data=module.weight,
            bias_data=module.bias,
        )
    return model

69

70
71
72
73
74
75
76
77
78
79
80
81
def _tensor_totals(tensors):
    """Return ``(total_bytes, total_elements)`` over an iterable of tensors."""
    total_bytes = 0
    total_elements = 0
    for tensor in tensors:
        count = tensor.nelement()
        total_bytes += count * tensor.element_size()
        total_elements += count
    return total_bytes, total_elements


def getModelSize(model):
    """Print and return the size of a model's parameters and buffers.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` that yield
            tensors (e.g. a ``torch.nn.Module``).

    Returns:
        Tuple ``(param_size, param_sum, buffer_size, buffer_sum, all_size)``:
        bytes and element count of the parameters, bytes and element count of
        the buffers, and the combined size in MiB (bytes / 1024 / 1024).
    """
    param_size, param_sum = _tensor_totals(model.parameters())
    buffer_size, buffer_sum = _tensor_totals(model.buffers())
    all_size = (param_size + buffer_size) / 1024 / 1024
    print(f"Model Size: {all_size:.3f}MB")
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)