"vscode:/vscode.git/clone" did not exist on "e6ba2000aef3e61ca84bb114472badecbd533ee9"
bitsandbytes.py 5.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314

    Args:
        adapter_name_or_path: Value of the ``adapter_name_or_path`` entry of
            the adapter config; stored unchanged on the instance.
        target_modules: Module names associated with the adapter; stored
            unchanged on the instance.
    """

    def __init__(
        self,
        adapter_name_or_path: str,
        target_modules: List[str],
    ) -> None:
        self.adapter_name_or_path = adapter_name_or_path
        self.target_modules = target_modules

    def __repr__(self) -> str:
        # Close the parenthesis so the representation is a balanced string.
        return ("BitsAndBytesConfig("
                f"adapter_name_or_path={self.adapter_name_or_path})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method.

        NOTE(review): 70 presumably encodes CUDA compute capability 7.0;
        confirm against the base-class contract.
        """
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for this quantization config."""
        return [
            "adapter_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a parsed configuration dictionary.

        Args:
            config: Dictionary that must provide ``adapter_name_or_path`` and,
                when that value is not the empty string, ``target_modules``.

        Returns:
            A new ``BitsAndBytesConfig``. When ``adapter_name_or_path`` is the
            empty string a built-in list of projection module names is used
            instead of reading ``target_modules`` from ``config``.
        """
        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
        default_target_modules = [
            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
            "o_proj"
        ]
        if adapter_name == "":
            target_modules = default_target_modules
        else:
            target_modules = cls.get_from_keys(config, ["target_modules"])
        return cls(adapter_name, target_modules)

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
        """Return the linear method for ``layer``.

        Returns:
            A ``BitsAndBytesLinearMethod`` bound to this config when ``layer``
            is a ``LinearBase`` instance, otherwise ``None``.
        """
        if isinstance(layer, LinearBase):
            return BitsAndBytesLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation function names reported by this config."""
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.
    """

    # Oldest accepted bitsandbytes release as [major, minor, patch].
    _MIN_BNB_VERSION = [0, 42, 0]

    @staticmethod
    def _version_key(version: str) -> List[int]:
        """Return the first three numeric release components of ``version``.

        Only the leading decimal digits of each dot-separated part are used,
        so suffixes such as ``rc1`` or ``.dev0`` are ignored; missing
        components are treated as zero.
        """
        key: List[int] = []
        for part in version.split(".")[:3]:
            digits = ""
            for ch in part:
                if ch not in "0123456789":
                    break
                digits += ch
            if not digits:
                break
            key.append(int(digits))
        key.extend([0] * (3 - len(key)))
        return key

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check that a usable bitsandbytes is installed and keep the config.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.42.0.
        """
        try:
            import bitsandbytes

            # Compare numerically: a plain string comparison would order
            # "0.100.0" before "0.42.0" and "0.9.0" after it.
            installed = self._version_key(bitsandbytes.__version__)
            if installed < self._MIN_BNB_VERSION:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed ``qweight`` parameter on ``layer``.

        The parameter is an uninitialised ``uint8`` tensor of shape
        ``[input_size_per_partition * sum(output_partition_sizes)
        // quant_ratio, 1]`` where ``quant_ratio`` is the bit width of
        ``params_dtype`` divided by 8.

        Args:
            layer: Module that receives the ``qweight`` parameter.
            input_size_per_partition: Input size of this partition.
            output_partition_sizes: Output sizes of the partitions held by
                this layer; they are summed.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype used to derive the packing ratio.
            **extra_weight_attrs: Extra attributes attached to ``qweight``.

        Raises:
            ValueError: If the element count is not divisible by the ratio.
        """
        uint8_bits = torch.iinfo(torch.uint8).bits
        if params_dtype.is_floating_point:
            quant_ratio = torch.finfo(params_dtype).bits // uint8_bits
        else:
            quant_ratio = torch.iinfo(params_dtype).bits // uint8_bits

        total_elements = input_size_per_partition * sum(output_partition_sizes)
        if total_elements % quant_ratio != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. ")
        qweight = Parameter(
            torch.empty(
                total_elements // quant_ratio,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                # In bitsandbytes, a tensor of shape [n, m] is quantized to
                # [n * m / pack_ratio, 1], so the output_dim is 0.
                "output_dim": 0,
                "pack_factor": quant_ratio,
                "use_bitsandbytes": True,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the 4-bit weight shards stored on ``layer``.

        Reads ``bnb_quant_state`` (indexed ``0..n-1``) and
        ``bnb_shard_offsets`` from ``layer.qweight``; these attributes are not
        set in this class (presumably attached by the weight loader; TODO
        confirm). The matmul runs in bfloat16 and the result is cast back to
        the dtype of ``x`` before ``bias`` is added.

        Args:
            layer: Module holding the ``qweight`` parameter.
            x: Input tensor; its first dimension becomes the first output
                dimension (assumes a 2-D input; TODO confirm with callers).
            bias: Optional bias added in place to the result.

        Returns:
            Tensor of shape ``[x.shape[0], sum of shard output sizes]`` in the
            dtype of ``x``.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import matmul_4bit

        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(state.shape[0] for state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]
            # It is more efficient to use out kwarg like
            # matmul_4bit(..., out = ...).  Infeasible now due to the bug
            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
            # Need to change  after the bug is fixed.
            out[:, current_index:current_index + output_size] = matmul_4bit(
                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])

            current_index += output_size

        out = out.to(original_type)

        if bias is not None:
            out += bias

        return out