bitblas.py 17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional
4
5

import torch
6
from packaging import version
7
8
9

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
10
11
12
13
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
14
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
15
16
17
18
19
    BITBLAS_OPTIMIZE_FEATURES,
    BITBLAS_SUPPORTED_NUM_BITS,
    BITBLAS_SUPPORTED_SYM,
    MINIMUM_BITBLAS_VERSION,
)
20
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
21
22
23
24
25
26
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedvLLMParameter,
)
27
28
29
30
31
32
33
34
35
36
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class BitBLASConfig(QuantizationConfig):
    """Config class for BitBLAS.

    Holds the quantization settings read from a checkpoint's
    ``quantize_config.json`` and hands out ``BitBLASLinearMethod`` instances
    for the layers that should run through BitBLAS kernels.

    Reference: https://github.com/Microsoft/BitBLAS
    """

    TORCH_DTYPE = torch.float16
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # "original" or "rescale" or "quantized",
    # gptq_with_bitblas prefer "quantized implementation"
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: int | None,
        desc_act: bool | None,
        is_sym: bool | None,
        quant_method: str | None,
        lm_head_quantized: bool,
    ) -> None:
        """Validate the settings and derive the packing parameters.

        Args:
            weight_bits: Bit-width of a quantized weight; must be one of
                ``BITBLAS_SUPPORTED_NUM_BITS``.
            group_size: Quantization group size; ``-1`` means a single group
                per output channel.
            desc_act: Activation-order flag. Forced to ``False`` when
                ``group_size == -1`` because it has no effect there.
            is_sym: Whether quantization is symmetric; must be one of
                ``BITBLAS_SUPPORTED_SYM``.
            quant_method: Name of the quantization method that produced the
                checkpoint (e.g. ``"gptq"``).
            lm_head_quantized: Whether the LM head is quantized as well.

        Raises:
            ValueError: If bitblas is missing or older than
                ``MINIMUM_BITBLAS_VERSION``, or if ``weight_bits`` / ``is_sym``
                is not supported.
        """
        try:
            import bitblas

            installed = version.parse(bitblas.__version__)
            if installed < version.parse(MINIMUM_BITBLAS_VERSION):
                # Raised as ImportError on purpose so that the handler below
                # reports a too-old install the same way as a missing one.
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
                )
        except ImportError as e:
            # Fix: the first fragment previously lacked a trailing space,
            # producing "...could not importwith the following error...".
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {e}. "
                "Please install bitblas through the following command: "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
            ) from e

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported."
            )

        if self.is_sym not in BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported."
            )

        # Bit-width of one storage element, parsed from the digits of the
        # dtype name (e.g. "int8" -> 8).
        storage_dtype = self.STORAGE_DTYPE
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

        self.storage_dtype = storage_dtype
        self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # Number of quantized weights packed into a single storage element.
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    def __repr__(self) -> str:
        """Return a debug representation listing the constructor settings."""
        return (
            f"BitBLASConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"is_sym={self.is_sym}, "
            f"quant_method={self.quant_method}, "
            f"lm_head_quantized={self.lm_head_quantized})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this backend."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this backend."""
        # TODO: 70 is a provisional value; the real minimum still needs to be
        # figured out. Presumably this encodes CUDA compute capability 7.0 -
        # confirm against the base class contract.
        return 70

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names that may hold the quantization config."""
        return ["quantize_config.json"]

    @staticmethod
    def get_from_keys(
        config: dict[str, Any], keys: list[str], default: Any = None
    ) -> Any:
        """Get a value from the model's quantization config.

        Args:
            config: The quantization config dictionary.
            keys: Candidate keys, tried in order.
            default: Value returned when none of ``keys`` is present.

        Returns:
            The value stored under the first key found in ``config``, or
            ``default`` if none of them is present.
        """
        for key in keys:
            if key in config:
                return config[key]
        return default

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
        """Build a ``BitBLASConfig`` from a quantization config dictionary.

        Missing entries fall back to: ``group_size=-1``, ``desc_act=False``,
        ``sym=False``, ``lm_head=False``; ``bits`` and ``quant_method`` fall
        back to ``None`` (a ``None`` bit-width is then rejected by
        ``__init__`` unless it is listed in ``BITBLAS_SUPPORTED_NUM_BITS``).
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"], -1)
        desc_act = cls.get_from_keys(config, ["desc_act"], False)
        is_sym = cls.get_from_keys(config, ["sym"], False)
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # NOTE(review): get_from_keys_or is not defined in this class; it is
        # presumably inherited from QuantizationConfig - confirm.
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(
            weight_bits, group_size, desc_act, is_sym, quant_method, lm_head_quantized
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select this backend when the checkpoint is in BitBLAS format.

        Args:
            hf_quant_cfg: Quantization config of the checkpoint; read through
                ``.get`` (presumably a dict - confirm against caller).
            user_quant: Quantization method requested by the user, or
                ``None`` if none was requested.

        Returns:
            ``cls.get_name()`` when the checkpoint is flagged as BitBLAS
            format and ``user_quant`` is ``None``, ``"gptq"`` or
            ``"bitblas"``; otherwise ``None``.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_bitblas_format: bool
        is_bitblas_format = hf_quant_cfg.get(
            "checkpoint_format"
        ) == "bitblas" or hf_quant_cfg.get("is_bitblas_format", False)

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "bitblas"
        )

        if is_bitblas_format and is_valid_user_quant:
            name = cls.get_name()
            # Lazy logger arguments: the message is only formatted when the
            # record is actually emitted.
            logger.info(
                "The model is serialized in %s format. Using %s kernel.",
                name,
                name,
            )
            return name

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["BitBLASLinearMethod"]:
        """Return the linear method for ``layer``, or ``None`` if not handled.

        A ``BitBLASLinearMethod`` is returned for every ``LinearBase`` layer,
        and for a ``ParallelLMHead`` only when ``lm_head_quantized`` is set.
        ``prefix`` is not used by this implementation.
        """
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            return BitBLASLinearMethod(self)
        return None


class BitBLASLinearMethod(LinearMethodBase):
    """Linear method for BitBLAS.

    Args:
        quant_config: The BitBLAS quantization config.
    """

    # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
    # Instead of BITBLAS_OPTIMIZE_FEATURES
    # If you want to high contiguous batching
    # performance
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING = True
    # Maps torch dtypes to the dtype names passed to BitBLAS.
    # torch.half is the same object as torch.float16, so it needs no entry
    # of its own.
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        """Store the quantization config used by all other methods."""
        self.quant_config = quant_config

    def create_weights_gptq(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create the GPTQ-style quantized parameters and attach them to ``layer``.

        Configures the BitBLAS matmul operator for this layer's shape and
        registers three parameters on ``layer``: ``qweight`` (packed quantized
        weights), ``scales`` and ``zeros``. All tensors are allocated
        uninitialized on the ``"cuda"`` device.

        Args:
            layer: Module that receives the registered parameters.
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: List of output partition sizes; their sum
                is the output size of this partition.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype: Data type of the scales (and of the non-quantized
                zeros); must be one of the config's supported activation
                dtypes.
            **extra_weight_attrs: Must contain ``"weight_loader"``, which is
                attached to every created parameter.

        Raises:
            ValueError: If `params_dtype` is not a supported activation dtype
                or if the input size per partition is not divisible by the
                group size in `quant_config`.
            KeyError: If ``"weight_loader"`` is missing from
                ``extra_weight_attrs``.
        """
        del input_size, output_size  # Unused arguments.
        weight_loader = extra_weight_attrs["weight_loader"]

        # Fix: the message used to claim only torch.float16 is accepted,
        # while the check allows every supported activation dtype.
        supported_dtypes = self.quant_config.get_supported_act_dtypes()
        if params_dtype not in supported_dtypes:
            raise ValueError(
                f"Parameter data type must be one of {supported_dtypes}, "
                f"but got {params_dtype}"
            )

        # A missing group size is treated as "one group per output channel".
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1

        output_size_per_partition = sum(output_partition_sizes)
        if group_size != -1 and input_size_per_partition % group_size != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size})."
            )

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self._configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            enable_tuning=self.ENABLE_TUNING,
            bias=False,
            layout="nt",
            bits=self.quant_config.weight_bits,
        )

        pack_factor = self.quant_config.pack_factor
        storage_torch_dtype = self.quant_config.storage_torch_dtype

        # The operator dictates the shape of the packed weight tensor.
        weight_shape = self.bitblas_matmul.retrieve_weight_shape()
        tile_size = weight_shape[-2] if self.bitblas_matmul.propagate_b else None

        # Packed quantized weights.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                weight_shape,
                device="cuda",
                dtype=storage_torch_dtype,
                requires_grad=False,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=pack_factor,
            bitblas_tile_size=tile_size,
            weight_loader=weight_loader,
        )

        # Number of quantization groups along the input dimension; a single
        # group means channel-wise quantization.
        input_groups = 1 if group_size == -1 else input_size_per_partition // group_size

        # Initialize scales and zeros for the quantized weights.
        weight_scale_args = {
            "data": torch.empty(
                output_size_per_partition,
                input_groups,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )

        if self.quant_config.zeros_mode == "quantized":
            # Zeros are stored packed, with the same pack factor as qweight.
            zeros = PackedvLLMParameter(
                data=torch.empty(
                    input_groups,
                    output_size_per_partition // pack_factor,
                    device="cuda",
                    dtype=storage_torch_dtype,
                    requires_grad=False,
                ),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                weight_loader=weight_loader,
            )
        else:
            zeros = BasevLLMParameter(
                torch.empty(
                    output_size_per_partition,
                    input_groups,
                    device="cuda",
                    dtype=params_dtype,
                ),
                weight_loader=weight_loader,
            )
            # Set attributes to indicate how scales and zeros are applied.
            set_weight_attrs(
                zeros,
                {
                    "input_dim": None if input_groups == 1 else 1,
                    "output_dim": 0,
                },
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros", zeros)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the quantized parameters for ``layer``.

        Dispatches on ``quant_config.quant_method``; only ``"gptq"`` is
        handled, by forwarding every argument to ``create_weights_gptq``.

        Raises:
            ValueError: If the quant method is anything other than ``"gptq"``.
        """
        if self.quant_config.quant_method != "gptq":
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}"
            )
        return self.create_weights_gptq(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
        out_dtype="float16",
    ):
        """Build the ``MatmulConfig`` and store the operator on ``self``.

        The resulting operator is assigned to ``self.bitblas_matmul``.

        Args:
            infeatures: Input feature count, passed as ``K``.
            outfeatures: Output feature count, passed as ``N``.
            params_dtype: torch dtype of the activations; must be a key of
                ``BITBLAS_DTYPES``.
            enable_tuning: Whether a newly created operator is tuned.
            bias: Passed through as ``with_bias``.
            layout: Passed through as the matmul layout.
            bits: Bit-width of the quantized weights.
            out_dtype: Output dtype name passed to BitBLAS.

        Raises:
            ValueError: If the quant method is anything other than ``"gptq"``.
            KeyError: If ``params_dtype`` is not in ``BITBLAS_DTYPES``.
        """
        from bitblas import MatmulConfig

        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

        if self.quant_config.quant_method != "gptq":
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}"
            )

        # GPTQ always uses scales. Symmetric quantization stores signed
        # weights without zero points; asymmetric stores unsigned weights
        # together with zero points.
        with_scaling = True
        if self.quant_config.is_sym:
            with_zeros = False
            W_dtype = f"int{bits}"
        else:
            with_zeros = True
            W_dtype = f"uint{bits}"

        # NOTE(review): out_dtype defaults to "float16" and the caller in
        # create_weights_gptq does not override it, even when params_dtype is
        # bfloat16 - confirm this is intended.
        # NOTE(review): group_size is taken from the config as-is and may be
        # None here (create_weights_gptq normalizes only its local copy).
        matmul_config = MatmulConfig(
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=out_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=self.quant_config.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=self.quant_config.group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=self.quant_config.zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning
        )

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return the BitBLAS ``Matmul`` operator for ``config``.

        The global operator cache is loaded from the on-disk database the
        first time it is found empty. On a cache hit the cached operator is
        returned. On a miss a new operator is built; when ``enable_tuning``
        is true it is tuned, added to the cache and the cache is saved back
        to the database.

        Args:
            config: The ``MatmulConfig`` describing the operator.
            enable_tuning: Whether to tune (and persist) a new operator.

        Returns:
            The cached or newly created operator.
        """
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(
                BITBLAS_DATABASE_PATH, BITBLAS_TARGET
            )

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is not None:
            logger.info(
                "BitBLAS Operator %s found in global_operator_cache.", config
            )
            return bitblas_matmul

        # Always construct without tuning; tuning is run explicitly below.
        bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
        if enable_tuning:
            logger.info("BitBLAS Operator %s is tuning ...", config)
            bitblas_matmul.hardware_aware_finetune(topk=20)
            global_operator_cache.add(config, bitblas_matmul)
            global_operator_cache.save_into_database(
                BITBLAS_DATABASE_PATH, BITBLAS_TARGET
            )
            logger.info(
                "BitBLAS Operator %s tuned and saved to database.", config
            )
        else:
            # NOTE(review): an untuned operator is not added to the cache, so
            # it is rebuilt for every layer with the same config - confirm
            # this is intended.
            logger.info("BitBLAS Operator %s created.", config)
        return bitblas_matmul

    def apply_gptq(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized matmul for a GPTQ layer.

        Args:
            layer: Module carrying the ``qweight``, ``scales`` and ``zeros``
                parameters.
            x: Input activations; all leading dimensions are flattened into
                one before the matmul and restored afterwards.
            bias: Optional bias added in place to the output.

        Returns:
            A tensor with the leading dimensions of ``x`` and the operator's
            output width as the last dimension.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.zeros

        # The operator works on 2-D inputs: (rows, in_features).
        x_2d = x.view(-1, x.shape[-1])

        if self.quant_config.is_sym:
            # Symmetric quantization has no zero points to pass.
            output_2d = self.bitblas_matmul(x_2d, qweight, scales)
        else:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

    def apply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the quantized linear layer.

        Dispatches on ``quant_config.quant_method``; only ``"gptq"`` is
        handled, by forwarding all arguments to ``apply_gptq``.

        Raises:
            ValueError: If the quant method is anything other than ``"gptq"``.
        """
        if self.quant_config.quant_method != "gptq":
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}"
            )
        return self.apply_gptq(*args, **kwargs)