# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch
from packaging import version

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
    BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class BitBLASConfig(QuantizationConfig):
    """Config class for BitBLAS.

    Stores the quantization settings (bit-width, group size, act-order
    flag, symmetry, source quantization method) together with the derived
    storage/packing information used when allocating BitBLAS weights.

    Reference: https://github.com/Microsoft/BitBLAS
    """
    TORCH_DTYPE = torch.float16
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # "original" or "rescale" or "quantized",
    # gptq_with_bitblas prefer "quantized implementation"
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: Optional[int],
        desc_act: Optional[bool],
        is_sym: Optional[bool],
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:
        """Validate the settings and derive the packing parameters.

        Args:
            weight_bits: Bit-width of a quantized weight; must be one of
                BITBLAS_SUPPORTED_NUM_BITS.
            group_size: Quantization group size; -1 is treated as a single
                group per output channel.
            desc_act: Act-order flag; forced to False when group_size is -1.
            is_sym: Whether quantization is symmetric; must be one of
                BITBLAS_SUPPORTED_SYM.
            quant_method: Name of the method that produced the checkpoint
                (e.g. "gptq").
            lm_head_quantized: Whether the LM head is quantized as well.

        Raises:
            ValueError: If bitblas is missing or older than
                MINIMUM_BITBLAS_VERSION, or if weight_bits / is_sym is
                not supported.
        """
        # Import lazily so that the module itself can be imported without
        # bitblas installed; a too-old version is reported the same way as
        # a missing package.
        try:
            import bitblas
            if version.parse(bitblas.__version__) < version.parse(
                    MINIMUM_BITBLAS_VERSION):
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            # The requirement is quoted so that a shell does not interpret
            # ">=" as an output redirection.
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {e}. "
                "Please install bitblas through the following command: "
                f'`pip install "bitblas>={MINIMUM_BITBLAS_VERSION}"`'
            ) from e

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")

        # Bit-width of one storage element, parsed from the dtype name
        # (e.g. "int8" -> 8).
        storage_dtype = self.STORAGE_DTYPE
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

        self.storage_dtype = storage_dtype
        self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # Number of quantized weights packed into one storage element
        # (e.g. two 4-bit weights per int8).
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    def __repr__(self) -> str:
        return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability required by this method."""
        # TODO: Need to figure it out
        return 70

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @staticmethod
    def get_from_keys(config: dict[str, Any],
                      keys: list[str],
                      default: Any = None) -> Any:
        """Get a value from the model's quantization config.

        Returns the value of the first key in `keys` that is present in
        `config`, or `default` when none of them is present.
        """
        for key in keys:
            if key in config:
                return config[key]
        return default

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
        """Build a BitBLASConfig from a quantization config dictionary.

        Reads "bits", "group_size" (default -1), "desc_act" (default
        False), "sym" (default False), "quant_method" and "lm_head"
        (default False).
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"], -1)
        desc_act = cls.get_from_keys(config, ["desc_act"], False)
        is_sym = cls.get_from_keys(config, ["sym"], False)
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # NOTE(review): get_from_keys_or is not defined in this class; it
        # is presumably inherited from QuantizationConfig - confirm.
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Select this method for checkpoints serialized in BitBLAS format.

        Returns this method's name when the checkpoint config marks the
        model as BitBLAS-formatted and the user-requested quantization is
        None, "gptq" or "bitblas"; otherwise returns None.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_bitblas_format: bool
        is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
                             or hf_quant_cfg.get("is_bitblas_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "bitblas")

        if is_bitblas_format and is_valid_user_quant:
            name = cls.get_name()
            logger.info(
                "The model is serialized in %s format. Using %s kernel.",
                name, name)
            return name

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitBLASLinearMethod"]:
        """Return the linear method for `layer`, or None if not handled.

        Linear layers always get a BitBLASLinearMethod; a ParallelLMHead
        gets one only when `lm_head_quantized` is set.
        """
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return BitBLASLinearMethod(self)
        return None


class BitBLASLinearMethod(LinearMethodBase):
    """Linear method for BitBLAS.

    Args:
        quant_config: The BitBLAS quantization config.
    """
    # Use BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS instead of
    # BITBLAS_OPTIMIZE_FEATURES for higher contiguous-batching performance.
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
    # Passed as `enable_tuning` when the matmul operator is configured.
    ENABLE_TUNING = True
    # Maps torch dtypes to the dtype names handed to the BitBLAS
    # MatmulConfig (torch.half is an alias of torch.float16).
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        self.quant_config = quant_config

    def create_weights_gptq(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Creates quantized weights for use in linear operations.

        Configures the BitBLAS matmul operator for this layer and registers
        the quantized weights ('qweight'), scales ('scales') and zeros
        ('zeros') as parameters on `layer`.

        Args:
            layer: The module the parameters are registered on.
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: List of output partition sizes.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype: The data type of the parameters; must be one of
                the dtypes returned by
                `quant_config.get_supported_act_dtypes()`.
            **extra_weight_attrs: Must contain "weight_loader", which is
                attached to every created parameter.

        Raises:
            ValueError: If `params_dtype` is not supported or if the input
                size per partition is not divisible by the group size
                in `quant_config`.
        """
        del input_size, output_size  # Unused arguments.
        weight_loader = extra_weight_attrs["weight_loader"]

        supported_dtypes = self.quant_config.get_supported_act_dtypes()
        if params_dtype not in supported_dtypes:
            raise ValueError(
                f"Parameter data type must be one of {supported_dtypes}, "
                f"but got {params_dtype}")
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1
        output_size_per_partition = sum(output_partition_sizes)
        # Validate that the input partition splits evenly into groups.
        if (group_size != -1 and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size}).")

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self._configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            enable_tuning=self.ENABLE_TUNING,
            bias=False,
            layout="nt",
            bits=self.quant_config.weight_bits,
        )

        # Packed quantized weights; the shape is dictated by the operator.
        weight_shape = self.bitblas_matmul.retrieve_weight_shape()
        qweight = PackedvLLMParameter(
            data=torch.empty(
                weight_shape,
                device="cuda",
                dtype=self.quant_config.storage_torch_dtype,
                requires_grad=False,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            bitblas_tile_size=(weight_shape[-2]
                               if self.bitblas_matmul.propagate_b else None),
            weight_loader=weight_loader,
        )

        # Number of groups along the input dimension; a single group means
        # channel-wise quantization.
        input_groups = (1 if group_size == -1 else input_size_per_partition //
                        group_size)

        # Initialize scales and zeros for the quantized weights.
        weight_scale_args = {
            "data":
            torch.empty(
                output_size_per_partition,
                input_groups,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=0,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=0,
                                              input_dim=1,
                                              **weight_scale_args)

        if self.quant_config.zeros_mode == "quantized":
            # Zeros are stored packed, in the same storage dtype as the
            # weights.
            zeros = PackedvLLMParameter(
                data=torch.empty(
                    input_groups,
                    output_size_per_partition // self.quant_config.pack_factor,
                    device="cuda",
                    dtype=self.quant_config.storage_torch_dtype,
                    requires_grad=False,
                ),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                weight_loader=weight_loader,
            )

        else:
            zeros = BasevLLMParameter(
                torch.empty(output_size_per_partition,
                            input_groups,
                            device="cuda",
                            dtype=params_dtype),
                weight_loader=weight_loader,
            )
            # Set attributes to indicate how scales and zeros are applied.
            set_weight_attrs(zeros, {
                "input_dim": None if input_groups == 1 else 1,
                "output_dim": 0,
            })

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros", zeros)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the layer's weights for the configured quant_method.

        Only "gptq" is supported; it is delegated to `create_weights_gptq`.

        Raises:
            ValueError: If `quant_config.quant_method` is not "gptq".
        """
        if self.quant_config.quant_method == "gptq":
            return self.create_weights_gptq(layer, input_size_per_partition,
                                            output_partition_sizes, input_size,
                                            output_size, params_dtype,
                                            **extra_weight_attrs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
        out_dtype="float16",
    ):
        """Build the MatmulConfig and store the operator on `self`.

        The resulting operator is assigned to `self.bitblas_matmul`.

        Raises:
            KeyError: If `params_dtype` has no entry in BITBLAS_DTYPES.
            ValueError: If `quant_config.quant_method` is not "gptq".
        """
        # NOTE(review): out_dtype defaults to "float16" and the caller in
        # this class does not override it, even when params_dtype is
        # bfloat16 - confirm this is intended.
        from bitblas import MatmulConfig
        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

        with_scaling = False
        with_zeros = False
        group_size = self.quant_config.group_size
        zeros_mode = self.quant_config.zeros_mode
        if self.quant_config.quant_method == "gptq":
            with_scaling = True
            if self.quant_config.is_sym:
                # Symmetric quantization: signed weights, no zero points.
                with_zeros = False
                W_dtype = f"int{bits}"
            else:
                # Asymmetric quantization: unsigned weights with zeros.
                with_zeros = True
                W_dtype = f"uint{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

        matmul_config = MatmulConfig(
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=out_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=self.quant_config.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return the Matmul operator for `config`, creating it if needed.

        The global operator cache is loaded from the BitBLAS database the
        first time it is found empty. On a cache miss a new operator is
        built; when `enable_tuning` is set it is tuned, added to the cache
        and the cache is saved back to the database.
        """
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache
        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                logger.info("BitBLAS Operator %s is tuning ...", config)
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                logger.info(
                    "BitBLAS Operator %s tuned and saved to database.",
                    config)
            else:
                # NOTE(review): an untuned operator is not added to the
                # cache, so it is rebuilt on the next miss - confirm this
                # is intended.
                logger.info("BitBLAS Operator %s created.", config)
        else:
            logger.info("BitBLAS Operator %s found in global_operator_cache.",
                        config)
        return bitblas_matmul

    def apply_gptq(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized matmul of `x` against the layer's weights.

        Args:
            layer: Module holding the `qweight`, `scales` and `zeros`
                parameters.
            x: Input activations; all leading dimensions are flattened
                into one before the matmul and restored afterwards.
            bias: Optional bias added in place to the output.

        Returns:
            The output tensor, with the leading dimensions of `x` and the
            operator's output width as the last dimension.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.zeros

        # reshape (rather than view) also accepts non-contiguous inputs;
        # it still returns a view whenever one is possible.
        x_2d = x.reshape(-1, x.shape[-1])

        if self.quant_config.is_sym:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales)
        else:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)

        output = output_2d.reshape(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

    def apply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the layer for the configured quant_method.

        Only "gptq" is supported; arguments are forwarded unchanged to
        `apply_gptq`.

        Raises:
            ValueError: If `quant_config.quant_method` is not "gptq".
        """
        if self.quant_config.quant_method == "gptq":
            return self.apply_gptq(*args, **kwargs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")