# bitsandbytes.py (22.7 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Callable, Optional, Union
5
6

import torch
7
from packaging import version
8

9
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
10
                                                        FusedMoEConfig,
11
                                                        FusedMoEMethodBase)
12
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
13
                                               UnquantizedLinearMethod,
14
                                               set_weight_attrs)
15
from vllm.model_executor.layers.quantization import QuantizationMethods
16
17
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
18
from vllm.platforms import current_platform
19
from vllm.utils import direct_register_custom_op
20
21
22
23
24
25
26
27


class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314

    The constructor stores every option verbatim on the instance. Within
    this class only two of them are interpreted:

    * ``bnb_4bit_quant_storage`` is validated (only ``"uint8"`` is accepted).
    * ``llm_int8_skip_modules`` decides, in ``get_quant_method``, which
      linear layers are left unquantized.

    Raises:
        ValueError: if ``bnb_4bit_quant_storage`` is not ``"uint8"``.
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float32",
        bnb_4bit_quant_storage: str = "uint8",
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        llm_int8_skip_modules: Optional[list[str]] = None,
        llm_int8_threshold: float = 6.0,
    ) -> None:
        super().__init__()
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        # Normalize None (and any other falsy value) to a fresh empty list so
        # the attribute can always be iterated.
        self.llm_int8_skip_modules = llm_int8_skip_modules or []
        self.llm_int8_threshold = llm_int8_threshold

        if self.bnb_4bit_quant_storage not in ["uint8"]:
            raise ValueError("Unsupported bnb_4bit_quant_storage: "
                             f"{self.bnb_4bit_quant_storage}")

    def __repr__(self) -> str:
        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
                f"load_in_4bit={self.load_in_4bit}, "
                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
                f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as an int.

        NOTE(review): 70 presumably stands for CUDA compute capability 7.0;
        confirm against the base-class contract.
        """
        return 70

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the extra config filenames to look for (none here)."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a dict, falling back to defaults.

        Every option is optional: a key that is missing (``get_from_keys``
        raises ``ValueError``) or whose value is ``None`` is replaced by the
        same default the constructor uses.
        """

        def get_safe_value(config, keys, default_value=None):
            # Only the lookup itself may raise; keep the try body minimal.
            try:
                value = cls.get_from_keys(config, keys)
            except ValueError:
                return default_value
            return value if value is not None else default_value

        load_in_8bit = get_safe_value(config, ["load_in_8bit"],
                                      default_value=False)
        load_in_4bit = get_safe_value(config, ["load_in_4bit"],
                                      default_value=True)
        bnb_4bit_compute_dtype = get_safe_value(config,
                                                ["bnb_4bit_compute_dtype"],
                                                default_value="float32")
        bnb_4bit_quant_storage = get_safe_value(config,
                                                ["bnb_4bit_quant_storage"],
                                                default_value="uint8")
        bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
                                             default_value="fp4")
        bnb_4bit_use_double_quant = get_safe_value(
            config, ["bnb_4bit_use_double_quant"], default_value=False)
        llm_int8_enable_fp32_cpu_offload = get_safe_value(
            config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
        llm_int8_has_fp16_weight = get_safe_value(config,
                                                  ["llm_int8_has_fp16_weight"],
                                                  default_value=False)
        llm_int8_skip_modules = get_safe_value(config,
                                               ["llm_int8_skip_modules"],
                                               default_value=[])
        llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
                                            default_value=6.0)

        return cls(
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            bnb_4bit_quant_storage=bnb_4bit_quant_storage,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
            llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
            llm_int8_skip_modules=llm_int8_skip_modules,
            llm_int8_threshold=llm_int8_threshold)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
        """Pick the quantization method for ``layer``.

        Linear layers whose ``prefix`` matches ``llm_int8_skip_modules`` get
        ``UnquantizedLinearMethod``; other linear layers get
        ``BitsAndBytesLinearMethod``; ``FusedMoE`` layers get
        ``BitsAndBytesMoEMethod``. Any other layer type yields ``None``.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                return UnquantizedLinearMethod()
            return BitsAndBytesLinearMethod(self)
        if isinstance(layer, FusedMoE):
            return BitsAndBytesMoEMethod(self, layer.moe_config)
        return None


140
def is_layer_skipped_bnb(prefix: str,
                         llm_int8_skip_modules: list[str]) -> bool:
    """Tell whether the layer named ``prefix`` must stay unquantized.

    A layer is skipped when an entry of ``llm_int8_skip_modules`` either

    * equals one dot-separated component of ``prefix`` exactly
      (``"gate"`` matches ``"model.layers.0.mlp.gate"`` but not
      ``"...gate_proj"``), or
    * equals a leading dotted path of ``prefix``
      (``"model.layers.0"`` matches ``"model.layers.0.mlp"``).

    Args:
        prefix: Dotted module name of the layer.
        llm_int8_skip_modules: Module names / dotted paths to leave
            unquantized.

    Returns:
        True if the layer should be skipped, False otherwise.
    """
    # Split the prefix into its dot-separated components
    components = prefix.split('.')
    skip_modules = set(llm_int8_skip_modules)

    # Check if any of the skip modules exactly matches any component
    if not skip_modules.isdisjoint(components):
        return True

    # Allow certain layers to not be quantized: compare the skip list with
    # every leading path "a", "a.b", "a.b.c", ... of the prefix.
    leading_paths = {
        ".".join(components[:end + 1])
        for end in range(len(components))
    }
    return not skip_modules.isdisjoint(leading_paths)
155
156


157
158
159
160
161
162
163
def calculate_quant_ratio(dtype):
    """Return how many ``uint8`` elements span one element of ``dtype``.

    The result is the bit width of ``dtype`` floor-divided by the bit width
    of ``torch.uint8``; e.g. 2 for ``torch.float16``, 4 for
    ``torch.float32``.
    """
    # finfo only accepts floating-point dtypes, iinfo only integer ones.
    type_info = torch.finfo if dtype.is_floating_point else torch.iinfo
    byte_bits = torch.iinfo(torch.uint8).bits
    return type_info(dtype).bits // byte_bits


164
165
166
167
168
169
170
171
172
173
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.

    Raises:
        ImportError: if bitsandbytes is missing or older than 0.46.1.
    """

    def __init__(self, quant_config: BitsAndBytesConfig):
        try:
            import bitsandbytes
            if version.parse(
                    bitsandbytes.__version__) < version.parse("0.46.1"):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.46.1.")
        except ImportError as err:
            # Both "not installed" and "too old" end up here and are
            # re-raised with installation instructions.
            raise ImportError("Please install bitsandbytes>=0.46.1 via "
                              "`pip install bitsandbytes>=0.46.1` to use "
                              "bitsandbytes quantizer.") from err

        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the quantized weight and register it as ``layer.weight``.

        With ``load_in_8bit`` the weight is an ``Int8Params`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition)``;
        otherwise it is a packed ``uint8`` column of
        ``total_size // quant_ratio`` rows.

        Raises:
            ValueError: in the 4-bit case, if the number of weight elements
                is not divisible by the pack factor of ``params_dtype``.
        """
        from bitsandbytes.nn import Int8Params

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(sum(output_partition_sizes),
                                 input_size_per_partition,
                                 dtype=torch.int8),
                has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
                requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": 1,
                    "use_bitsandbytes_8bit": True,
                    # Counts calls of _apply_8bit_weight on this weight.
                    "generation": 0
                })
            return qweight

        def create_qweight_for_4bit():
            quant_ratio = calculate_quant_ratio(params_dtype)

            total_size = input_size_per_partition * sum(output_partition_sizes)
            if total_size % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized "
                    "weight shape.")

            qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
                                                     1,
                                                     dtype=torch.uint8),
                                         requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": quant_ratio,
                    "use_bitsandbytes_4bit": True
                })
            return qweight

        if self.quant_config.load_in_8bit:
            qweight = create_qweight_for_8bit()
        else:
            qweight = create_qweight_for_4bit()
        # Enable parameters to have the same name as in the BNB
        # checkpoint format.
        layer.register_parameter("weight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized linear layer, dispatching on ``load_in_8bit``."""
        if self.quant_config.load_in_8bit:
            return self._apply_8bit_weight(layer, x, bias)
        return self._apply_4bit_weight(layer, x, bias)

    def _apply_8bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` shard by shard for 8-bit weights.

        Reads ``bnb_shard_offsets``, ``bnb_quant_state``, ``matmul_state``
        and ``generation`` from ``layer.weight`` and increments
        ``generation`` on every call. The result has the dtype of ``x`` and
        its leading dimensions.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import MatmulLtState, matmul

        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten all leading dims; restored after the matmul.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        offsets = qweight.bnb_shard_offsets
        quant_states = qweight.bnb_quant_state
        matmul_states = qweight.matmul_state
        generation = qweight.generation

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        # NOTE(review): the buffer is float16 while the matmul input is
        # bfloat16; kept as in the original, results are cast on assignment.
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.float16,
                          device=x.device)

        # Loop-invariant: the same input is fed to every shard.
        new_x = bf_x.unsqueeze(0)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]

            # in profile_run or the first generation of inference,
            # create new matmul_states
            if generation in (0, 1):
                matmul_states[i] = MatmulLtState()
                matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
                matmul_states[i].SCB = quant_states[i].to(x.device)
                matmul_states[i].threshold = (
                    self.quant_config.llm_int8_threshold)
                matmul_states[i].has_fp16_weights = (
                    self.quant_config.llm_int8_has_fp16_weight)
                matmul_states[i].is_training = False
                if matmul_states[i].threshold > 0.0 and not matmul_states[
                        i].has_fp16_weights:
                    matmul_states[i].use_pool = True

            out[:, current_index:current_index + output_size] = matmul(
                new_x,
                qweight[offsets[i]:offsets[i + 1]],
                state=matmul_states[i])

            current_index += output_size

            # only update the matmul_states if it is not profile_run
            if (generation > 0
                    and not self.quant_config.llm_int8_has_fp16_weight
                    and matmul_states[i].CB is not None
                    and matmul_states[i].CxB is not None):
                del matmul_states[i].CB
                qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB

        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        qweight.generation += 1

        return out

    def _apply_4bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` for 4-bit weights.

        The per-shard matmul is delegated to the ``apply_bnb_4bit`` op, which
        fills a bfloat16 buffer in place; the result is then cast back to the
        dtype of ``x`` and reshaped to its leading dimensions.
        """
        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten all leading dims; restored after the matmul.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)
        apply_bnb_4bit(bf_x, qweight, offsets, out)
        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        return out
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396


def _apply_bnb_4bit(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Fill ``out`` in place with the 4-bit matmul of ``x`` and ``weight``.

    ``weight`` is consumed shard by shard: rows
    ``offsets[i]:offsets[i + 1]`` are multiplied using the quant state
    ``weight.bnb_quant_state[i]``, and each result is written to the next
    ``state.shape[0]`` columns of ``out``.
    """
    # only load the bitsandbytes module when needed
    from bitsandbytes import matmul_4bit

    shard_states = weight.bnb_quant_state
    col_begin = 0
    for shard in range(len(shard_states)):
        col_end = col_begin + shard_states[shard].shape[0]
        packed_shard = weight[offsets[shard]:offsets[shard + 1]]
        # It is more efficient to use out kwarg like
        # matmul_4bit(..., out = ...).  Infeasible now due to the bug
        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
        # Need to change  after the bug is fixed.
        out[:, col_begin:col_end] = matmul_4bit(x, packed_shard.t(),
                                                shard_states[shard])
        col_begin = col_end


def _apply_bnb_4bit_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    return


# Register `_apply_bnb_4bit` as a custom op that mutates its `out` argument,
# with `_apply_bnb_4bit_fake` as its fake implementation, then expose it under
# the module-level name `apply_bnb_4bit`.
# The original wrapped these statements in
# `except AttributeError as error: raise error`, which only re-raised the
# same exception; any AttributeError still propagates to the importer.
direct_register_custom_op(op_name="apply_bnb_4bit",
                          op_func=_apply_bnb_4bit,
                          mutates_args=["out"],
                          fake_impl=_apply_bnb_4bit_fake,
                          dispatch_key=current_platform.dispatch_key)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
406
407
408
409
410
411
412
413
414


class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    """MoE method for BitsAndBytes.

    Only the 4-bit path is implemented; the 8-bit hooks raise
    ``NotImplementedError``.

    Args:
       quant_config: The BitsAndBytes quantization config.
       moe: The fused-MoE config, forwarded to the base class.

    Raises:
        ImportError: if bitsandbytes is missing or older than 0.46.1.
    """

    def __init__(
        self,
        quant_config: BitsAndBytesConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        try:
            import bitsandbytes
            if version.parse(
                    bitsandbytes.__version__) < version.parse("0.46.1"):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.46.1.")
        except ImportError as err:
            # Both "not installed" and "too old" end up here and are
            # re-raised with installation instructions.
            raise ImportError("Please install bitsandbytes>=0.46.1 via "
                              "`pip install bitsandbytes>=0.46.1` to use "
                              "bitsandbytes quantizer.") from err
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights, dispatching on ``load_in_8bit``."""
        if self.quant_config.load_in_8bit:
            call_fun = self._create_weights_8bit
        else:
            call_fun = self._create_weights_4bit
        call_fun(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Route tokens to experts and run the fused expert computation.

        Expert selection is done by ``FusedMoE.select_experts``; the expert
        weights are dequantized on every call and handed to
        ``fused_experts`` (invoked with ``inplace=True``).

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used here.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set, or (through the
                dequant helper) if the config requests 8-bit weights.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `BitsAndBytesMoEMethod` yet.")
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)
        if self.quant_config.load_in_8bit:
            w13, w2 = self._apply_8bit_dequant(layer)
        else:
            w13, w2 = self._apply_4bit_dequnt(layer)
        return fused_experts(
            hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    def _create_weights_4bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed ``uint8`` parameters ``w13_weight``/``w2_weight``.

        Each parameter has shape ``(num_experts, packed_size, 1)`` and
        carries an ``experts_shape`` attribute holding the shape used to
        reshape the weights after dequantization.
        """
        quant_ratio = calculate_quant_ratio(params_dtype)
        # Fused gate_up_proj (column parallel)
        w13_total_size = (hidden_size * 2 *
                          intermediate_size_per_partition) // quant_ratio
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        set_weight_attrs(
            w13_qweight,
            {
                "num_experts": num_experts,
                "input_dim": hidden_size,
                "output_dim": 2 * intermediate_size_per_partition,
                "experts_shape": (
                    num_experts,
                    intermediate_size_per_partition * 2,
                    hidden_size,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        # down_proj (row parallel)
        w2_total_size = (hidden_size *
                         intermediate_size_per_partition) // quant_ratio
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w2_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            w2_qweight,
            {
                "num_experts": num_experts,
                "input_dim": intermediate_size_per_partition,
                "output_dim": hidden_size,
                "experts_shape": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        layer.register_parameter("w2_weight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

    def _create_weights_8bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """8-bit expert weights are not implemented."""
        raise NotImplementedError

    def _apply_4bit_dequnt(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        """Dequantize both expert weights and return ``(w13, w2)``.

        Each packed parameter is flattened to a column, dequantized with its
        ``bnb_quant_state`` and reshaped to its ``experts_shape``.
        (The misspelled method name is kept so existing callers still work.)
        """
        from bitsandbytes.functional import dequantize_4bit
        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

    def _apply_8bit_dequant(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        """8-bit expert dequantization is not implemented."""
        raise NotImplementedError