from dataclasses import dataclass
from functools import lru_cache

import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
from loguru import logger
from text_generation_server.utils.weights import Weight


@lru_cache(1)
def warn_deprecate_bnb():
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
    )
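
# Note: because of the `lru_cache(1)` above, the deprecation warning is logged
# at most once per process, no matter how many quantized layers are built:
#
#     warn_deprecate_bnb()  # logs the warning
#     warn_deprecate_bnb()  # cached, nothing is logged again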


@dataclass
class BNBWeight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
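
# A minimal usage sketch (hypothetical tensors; a CUDA device and a working
# bitsandbytes install are assumed): `BNBWeight.get_linear` returns a module
# that is used like a regular `torch.nn.Linear`, e.g.
#
#     w = torch.randn(128, 64, dtype=torch.float16, device="cuda")  # (out, in)
#     linear = BNBWeight(w).get_linear(bias=None)
#     y = linear(torch.randn(2, 64, dtype=torch.float16, device="cuda"))  # -> (2, 128)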


class Linear8bitLt(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
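
# Sketch of the int8 state handoff above (hypothetical tensors; CUDA assumed):
# on the first forward pass the row-major CB/SCB buffers move from the
# parameter onto the MatmulLtState, and once bitsandbytes has produced the
# Turing/Ampere layout (CxB) the row-major copy is dropped to save memory.
#
#     layer = Linear8bitLt(
#         torch.randn(128, 64, dtype=torch.float16, device="cuda"),
#         bias=None, has_fp16_weights=False, threshold=6.0,
#     )
#     _ = layer(torch.randn(2, 64, dtype=torch.float16, device="cuda"))
#     # layer.weight.CB is now None; on supported GPUs layer.state.CxB holds
#     # the reordered weight from this point on.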


@dataclass
class BNBFP4Weight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return Linear4bit(self.weight, bias, quant_type="fp4")


@dataclass
class BNBNF4Weight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return Linear4bit(self.weight, bias, quant_type="nf4")


class Linear4bit(torch.nn.Module):
    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            logger.warning(
                "4-bit quantization state not initialized. Please call .cuda() or .to(device) on the Linear4bit layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out
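
# End-to-end sketch for the 4-bit path (hypothetical tensors; CUDA and
# bitsandbytes assumed): the weight is quantized when `Params4bit` is moved to
# the GPU in __init__, and forward dequantizes on the fly via bnb.matmul_4bit.
#
#     w = torch.randn(128, 64, dtype=torch.float16, device="cuda")
#     layer = Linear4bit(w, bias=None, quant_type="nf4")
#     y = layer(torch.randn(2, 64, dtype=torch.float16, device="cuda"))  # -> (2, 128)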