bnb.py 3.95 KB
Newer Older
1
2
from dataclasses import dataclass

Nicolas Patry's avatar
Nicolas Patry committed
3
import bitsandbytes as bnb
4
import torch
Nicolas Patry's avatar
Nicolas Patry committed
5
from bitsandbytes.nn import Int8Params, Params4bit
6
from text_generation_server.utils.weights import UnquantizedWeight
Nicolas Patry's avatar
Nicolas Patry committed
7
8


9
@dataclass
class BNBWeight(UnquantizedWeight):
    """Weight holder whose linear layer is a ``Linear8bitLt``."""

    # Tensor handed to ``Linear8bitLt`` when the layer is built.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Build the ``Linear8bitLt`` layer for this weight.

        Args:
            bias: Bias passed through to the layer unchanged.

        Returns:
            A ``Linear8bitLt`` created with ``has_fp16_weights=False`` and
            ``threshold=6.0``.
        """
        layer = Linear8bitLt(
            weight=self.weight,
            bias=bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        return layer


Nicolas Patry's avatar
Nicolas Patry committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class Linear8bitLt(torch.nn.Module):
    """Linear layer computed with ``bnb.matmul`` over an ``Int8Params`` weight.

    The module keeps a ``bnb.MatmulLtState`` in ``self.state`` and passes it to
    every ``bnb.matmul`` call made by ``forward``.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        """Wrap ``weight`` in ``Int8Params`` and configure the matmul state.

        Args:
            weight: Tensor whose ``.data`` is wrapped in ``Int8Params``.
            bias: Bias tensor stored on the module, or ``None``.
            has_fp16_weights: Stored on the state and forwarded to
                ``Int8Params`` (also used there as ``requires_grad``).
            memory_efficient_backward: Deprecated; must be falsy.
            threshold: Stored on the state. A value above ``0.0`` together
                with ``has_fp16_weights=False`` sets ``state.use_pool``.
            index: Stored on the module as ``self.index``.

        Raises:
            AssertionError: If ``memory_efficient_backward`` is truthy.
        """
        super().__init__()
        # Raised explicitly instead of through ``assert`` so the check is not
        # stripped when Python runs with ``-O``. The exception type and the
        # message are kept for callers that relied on the assertion.
        if memory_efficient_backward:
            raise AssertionError(
                "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
            )
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        # The return value is intentionally unused; presumably ``Int8Params.cuda``
        # prepares the parameter in place -- confirm against bitsandbytes.
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move ``CB`` and ``SCB`` from the weight onto ``self.state``.

        After the call both attributes are ``None`` on the weight.
        """
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Return ``bnb.matmul(x, self.weight, bias=self.bias, state=self.state)``.

        Args:
            x: Input tensor; the bias is cast to its dtype when they differ.

        Returns:
            The result of ``bnb.matmul``.
        """
        self.state.is_training = self.training
        # The weight still carries CB/SCB until the first call hands them to
        # the state.
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out


75
@dataclass
class BNBFP4Weight(UnquantizedWeight):
    """Weight holder whose linear layer is a ``Linear4bit`` using ``"fp4"``."""

    # Tensor handed to ``Linear4bit`` when the layer is built.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Build the ``Linear4bit`` layer for this weight.

        Args:
            bias: Bias passed through to the layer unchanged.

        Returns:
            A ``Linear4bit`` created with ``quant_type="fp4"``.
        """
        layer = Linear4bit(
            weight=self.weight,
            bias=bias,
            quant_type="fp4",
        )
        return layer


@dataclass
class BNBNF4Weight(UnquantizedWeight):
    """Weight holder whose linear layer is a ``Linear4bit`` using ``"nf4"``."""

    # Tensor handed to ``Linear4bit`` when the layer is built.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Build the ``Linear4bit`` layer for this weight.

        Args:
            bias: Bias passed through to the layer unchanged.

        Returns:
            A ``Linear4bit`` created with ``quant_type="nf4"``.
        """
        layer = Linear4bit(
            weight=self.weight,
            bias=bias,
            quant_type="nf4",
        )
        return layer


91
class Linear4bit(torch.nn.Module):
    """Linear layer computed with ``bnb.matmul_4bit`` over a ``Params4bit`` weight."""

    def __init__(self, weight, bias, quant_type):
        """Wrap ``weight`` in ``Params4bit`` and store the bias.

        Args:
            weight: Tensor whose ``.data`` is wrapped in ``Params4bit``.
            bias: Bias tensor stored on the module, or ``None``.
            quant_type: Forwarded to ``Params4bit`` (callers in this file pass
                ``"fp4"`` or ``"nf4"``).
        """
        super().__init__()
        # No compute dtype is configured here, so ``forward`` leaves the input
        # dtype alone unless this attribute is set later.
        self.compute_dtype = None
        packed = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        self.weight = packed
        # The return value is intentionally unused; presumably ``Params4bit.cuda``
        # prepares the parameter in place -- confirm against bitsandbytes.
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        """Apply the 4-bit matmul and return the result in the dtype of ``x``.

        Args:
            x: Input tensor. It is cast to ``self.compute_dtype`` first when
                that attribute is not ``None``.

        Returns:
            The output of ``bnb.matmul_4bit`` cast back to the input dtype.
        """
        layer_bias = self.bias
        # The bias is not cast automatically; align it with the input dtype by hand.
        if layer_bias is not None and layer_bias.dtype != x.dtype:
            layer_bias.data = layer_bias.data.to(x.dtype)

        # Only reports the missing quant state; execution continues afterwards.
        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )

        original_dtype = x.dtype
        target_dtype = self.compute_dtype
        if target_dtype is not None:
            x = x.to(target_dtype)

        if layer_bias is None:
            bias = None
        else:
            bias = layer_bias.to(target_dtype)

        result = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )
        return result.to(original_dtype)