bitsandbytes.py 20.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Union
5
6

import torch
7
from packaging import version
8

9
10
11
12
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
13
14
15
16
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
17
18
19
20
21
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
    set_weight_attrs,
)
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
27
from vllm.platforms import current_platform
28
from vllm.utils.torch_utils import direct_register_custom_op
29
30
31
32
33
34
35
36


class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float32",
        bnb_4bit_quant_storage: str = "uint8",
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        llm_int8_skip_modules: list[str] | None = None,
        llm_int8_threshold: float = 6.0,
    ) -> None:
        """Store the BitsAndBytes quantization settings.

        Raises:
            ValueError: if ``bnb_4bit_quant_storage`` is not ``"uint8"``, the
                only storage type accepted here.
        """
        super().__init__()
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        # ``None`` (the default) means "skip nothing".
        self.llm_int8_skip_modules = llm_int8_skip_modules or []
        self.llm_int8_threshold = llm_int8_threshold

        if self.bnb_4bit_quant_storage not in ["uint8"]:
            raise ValueError(
                f"Unsupported bnb_4bit_quant_storage: {self.bnb_4bit_quant_storage}"
            )

    def __repr__(self) -> str:
        return (
            f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
            f"load_in_4bit={self.load_in_4bit}, "
            f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
            f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
            f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
            f"llm_int8_skip_modules={self.llm_int8_skip_modules})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the quantization method identifier."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (70)."""
        return 70

    @staticmethod
    def get_config_filenames() -> list[str]:
        """No extra config files are required for BitsAndBytes."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a ``quantization_config``-style dict.

        Missing keys and keys whose value is ``None`` fall back to the same
        defaults used by ``__init__``.
        """

        def get_safe_value(config, keys, default_value=None):
            # ``get_from_keys`` raises ValueError when none of ``keys`` exist.
            try:
                value = cls.get_from_keys(config, keys)
                return value if value is not None else default_value
            except ValueError:
                return default_value

        load_in_8bit = get_safe_value(config, ["load_in_8bit"], default_value=False)
        load_in_4bit = get_safe_value(config, ["load_in_4bit"], default_value=True)
        bnb_4bit_compute_dtype = get_safe_value(
            config, ["bnb_4bit_compute_dtype"], default_value="float32"
        )
        bnb_4bit_quant_storage = get_safe_value(
            config, ["bnb_4bit_quant_storage"], default_value="uint8"
        )
        bnb_4bit_quant_type = get_safe_value(
            config, ["bnb_4bit_quant_type"], default_value="fp4"
        )
        bnb_4bit_use_double_quant = get_safe_value(
            config, ["bnb_4bit_use_double_quant"], default_value=False
        )
        llm_int8_enable_fp32_cpu_offload = get_safe_value(
            config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False
        )
        llm_int8_has_fp16_weight = get_safe_value(
            config, ["llm_int8_has_fp16_weight"], default_value=False
        )
        llm_int8_skip_modules = get_safe_value(
            config, ["llm_int8_skip_modules"], default_value=[]
        )
        llm_int8_threshold = get_safe_value(
            config, ["llm_int8_threshold"], default_value=6.0
        )

        return cls(
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            bnb_4bit_quant_storage=bnb_4bit_quant_storage,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
            llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
            llm_int8_skip_modules=llm_int8_skip_modules,
            llm_int8_threshold=llm_int8_threshold,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None:
        """Pick the quantization method for ``layer``.

        Linear layers listed in ``llm_int8_skip_modules`` stay unquantized;
        other linear layers and fused-MoE layers get the BitsAndBytes method.
        Any other layer type returns ``None``.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                return UnquantizedLinearMethod()
            return BitsAndBytesLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BitsAndBytesMoEMethod(self, layer.moe_config)
        return None


154
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
    """Decide whether the layer named ``prefix`` is excluded from quantization.

    ``prefix`` is treated as a dot-separated module path. The layer is skipped
    when some entry of ``llm_int8_skip_modules`` either:

    * equals one individual path component (``"mlp"`` matches
      ``"model.layers.0.mlp"``), or
    * equals a leading sub-path (``"model.layers"`` matches
      ``"model.layers.0.mlp"``).

    Returns:
        ``True`` if the layer should be left unquantized, else ``False``.
    """
    skip_names = set(llm_int8_skip_modules)
    parts = prefix.split(".")

    # Every leading dotted path: "a", "a.b", "a.b.c", ...
    leading_paths = {".".join(parts[:end]) for end in range(1, len(parts) + 1)}

    matches_component = not skip_names.isdisjoint(parts)
    matches_leading_path = not skip_names.isdisjoint(leading_paths)
    return matches_component or matches_leading_path
169
170


171
172
173
174
175
176
177
def calculate_quant_ratio(dtype):
    """Return how many ``uint8`` bytes one element of ``dtype`` spans.

    Computed as the bit width of ``dtype`` divided by the bit width of
    ``torch.uint8``. Callers use the result as the pack factor between
    ``dtype`` and the ``uint8`` storage of quantized weights.
    """
    type_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    byte_bits = torch.iinfo(torch.uint8).bits
    return type_info.bits // byte_bits


178
179
180
181
182
183
184
185
186
187
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.
    """

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Verify that a recent-enough bitsandbytes is importable.

        Raises:
            ImportError: if bitsandbytes is missing or older than 0.46.1.
        """
        try:
            import bitsandbytes

            if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
                raise ImportError(
                    "bitsandbytes version is wrong. Please "
                    "install bitsandbytes>=0.46.1."
                )
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes>=0.46.1 via "
                "`pip install bitsandbytes>=0.46.1` to use "
                "bitsandbytes quantizer."
            ) from err

        self.quant_config = quant_config

    @staticmethod
    def _total_output_size(quant_states) -> int:
        """Sum the leading dimension of every per-shard quant state.

        ``quant_states`` is used as a mapping here (``.values()``); each value
        must expose ``.shape``.
        """
        return sum(state.shape[0] for state in quant_states.values())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an empty quantized ``weight`` parameter on ``layer``.

        8-bit: an ``Int8Params`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        4-bit: a ``uint8`` tensor of shape ``(total_size // quant_ratio, 1)``
        where ``quant_ratio`` comes from ``calculate_quant_ratio(params_dtype)``.

        Raises:
            ValueError: (4-bit) if the element count is not divisible by
                ``quant_ratio``.
        """
        from bitsandbytes.nn import Int8Params

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=torch.int8,
                ),
                has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
                requires_grad=False,
            )
            set_weight_attrs(
                qweight,
                {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": 1,
                    "use_bitsandbytes_8bit": True,
                    # Incremented by _apply_8bit_weight on every call.
                    "generation": 0,
                },
            )
            return qweight

        def create_qweight_for_4bit():
            quant_ratio = calculate_quant_ratio(params_dtype)

            total_size = input_size_per_partition * sum(output_partition_sizes)
            if total_size % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized weight shape."
                )

            qweight = torch.nn.Parameter(
                torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight,
                {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": quant_ratio,
                    "use_bitsandbytes_4bit": True,
                },
            )
            return qweight

        if self.quant_config.load_in_8bit:
            qweight = create_qweight_for_8bit()
        else:
            qweight = create_qweight_for_4bit()
        # Enable parameters to have the same name as in the BNB
        # checkpoint format.
        layer.register_parameter("weight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dispatch to the 8-bit or 4-bit forward based on the config."""
        if self.quant_config.load_in_8bit:
            return self._apply_8bit_weight(layer, x, bias)
        else:
            return self._apply_4bit_weight(layer, x, bias)

    def _apply_8bit_weight(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by the 8-bit weight one shard at a time.

        While ``qweight.generation`` is 0 or 1, a fresh ``MatmulLtState`` is
        built for each shard. Later calls reuse the stored states.
        ``qweight.generation`` is incremented once per call.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import MatmulLtState, matmul

        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)
        # ``matmul`` is fed the input with a leading dim of 1. The view does
        # not depend on the shard, so build it once outside the loop.
        new_x = bf_x.unsqueeze(0)

        qweight = layer.weight
        offsets = qweight.bnb_shard_offsets
        quant_states = qweight.bnb_quant_state
        matmul_states = qweight.matmul_state
        generation = qweight.generation

        out_dim_0 = x.shape[0]
        out_dim_1 = self._total_output_size(quant_states)
        # NOTE(review): the buffer is float16 although the matmul input is
        # bfloat16; if ``matmul`` returns bfloat16, values outside the float16
        # range would overflow when copied here. Confirm this is intended.
        out = torch.empty(out_dim_0, out_dim_1, dtype=torch.float16, device=x.device)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]

            # in profile_run or the first generation of inference,
            # create new matmul_states
            if generation in (0, 1):
                state = MatmulLtState()
                matmul_states[i] = state
                state.CB = qweight[offsets[i] : offsets[i + 1]]
                state.SCB = quant_states[i].to(x.device)
                state.threshold = self.quant_config.llm_int8_threshold
                state.has_fp16_weights = self.quant_config.llm_int8_has_fp16_weight
                state.is_training = False
                if state.threshold > 0.0 and not state.has_fp16_weights:
                    state.use_pool = True
            else:
                state = matmul_states[i]

            out[:, current_index : current_index + output_size] = matmul(
                new_x, qweight[offsets[i] : offsets[i + 1]], state=state
            )

            current_index += output_size

            # only update the matmul_states if it is not profile_run
            if (
                generation > 0
                and not self.quant_config.llm_int8_has_fp16_weight
                and state.CB is not None
                and state.CxB is not None
            ):
                del state.CB
                qweight[offsets[i] : offsets[i + 1]] = state.CxB

        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        qweight.generation += 1

        return out

    def _apply_4bit_weight(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by the 4-bit weight through the ``apply_bnb_4bit`` op.

        The computation runs in bfloat16; the result is cast back to ``x``'s
        dtype and reshaped to match any leading dims of ``x``.
        """
        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = self._total_output_size(quant_states)
        out = torch.empty(out_dim_0, out_dim_1, dtype=torch.bfloat16, device=x.device)
        # Writes the result into ``out`` in place.
        apply_bnb_4bit(bf_x, qweight, offsets, out)
        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        return out
394
395
396
397
398
399
400
401
402
403


def _apply_bnb_4bit(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Compute ``x @ W`` shard by shard and write the results into ``out``.

    ``weight`` holds the packed 4-bit shards back to back. Shard ``k`` spans
    rows ``offsets[k]:offsets[k + 1]``, and ``weight.bnb_quant_state[k]``
    describes it. Each shard's product fills the next
    ``quant_state.shape[0]`` columns of ``out``. Nothing is returned.
    """
    # only load the bitsandbytes module when needed
    from bitsandbytes import matmul_4bit

    shard_states = weight.bnb_quant_state
    column = 0
    for shard in range(len(shard_states)):
        state = shard_states[shard]
        width = state.shape[0]
        packed = weight[offsets[shard] : offsets[shard + 1]]
        # It is more efficient to use out kwarg like
        # matmul_4bit(..., out = ...).  Infeasible now due to the bug
        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
        # Need to change  after the bug is fixed.
        out[:, column : column + width] = matmul_4bit(x, packed.t(), state)
        column += width


def _apply_bnb_4bit_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Fake implementation registered for the ``apply_bnb_4bit`` custom op.

    The real op only writes into ``out`` and returns nothing. There is
    therefore no output to describe, and this stub deliberately does nothing.
    """
    return None


# Register ``_apply_bnb_4bit`` as the ``vllm.apply_bnb_4bit`` custom op. The op
# mutates its ``out`` argument, and ``_apply_bnb_4bit_fake`` is its fake
# implementation. Registration errors simply propagate. The previous
# ``except AttributeError as error: raise error`` wrapper only re-raised the
# same exception.
direct_register_custom_op(
    op_name="apply_bnb_4bit",
    op_func=_apply_bnb_4bit,
    mutates_args=["out"],
    fake_impl=_apply_bnb_4bit_fake,
    dispatch_key=current_platform.dispatch_key,
)

apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
440
441
442
443
444
445
446
447
448


class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    """MoE method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.
       moe: The fused-MoE configuration, forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: BitsAndBytesConfig,
        moe: FusedMoEConfig,
    ):
        """Verify that a recent-enough bitsandbytes is importable.

        Raises:
            ImportError: if bitsandbytes is missing or older than 0.46.1.
        """
        super().__init__(moe)
        try:
            import bitsandbytes

            if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
                raise ImportError(
                    "bitsandbytes version is wrong. Please "
                    "install bitsandbytes>=0.46.1."
                )
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes>=0.46.1 via "
                "`pip install bitsandbytes>=0.46.1` to use "
                "bitsandbytes quantizer."
            ) from err
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights for the configured bit width."""
        if self.quant_config.load_in_8bit:
            call_fun = self._create_weights_8bit
        else:
            call_fun = self._create_weights_4bit
        call_fun(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # ``apply`` passes already-dequantized weights to ``fused_experts``.
        return None

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dequantize the expert weights and run ``fused_experts`` on them."""
        from vllm.model_executor.layers.fused_moe import fused_experts

        # TODO(bnell): Do these need to be called on the hot path?
        if self.quant_config.load_in_8bit:
            w13, w2 = self._apply_8bit_dequant(layer)
        else:
            w13, w2 = self._apply_4bit_dequant(layer)
        return fused_experts(
            hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    def _create_weights_4bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed ``uint8`` ``w13_weight`` and ``w2_weight`` params.

        Each param has shape ``(num_experts, num_elements // quant_ratio, 1)``.
        Its ``experts_shape`` attribute records the dequantized shape.

        Raises:
            ValueError: if a weight's element count is not divisible by
                ``quant_ratio`` (same check as the linear 4-bit path).
        """
        quant_ratio = calculate_quant_ratio(params_dtype)
        w13_num_elements = hidden_size * 2 * intermediate_size_per_partition
        w2_num_elements = hidden_size * intermediate_size_per_partition
        # Floor division would otherwise silently allocate a buffer that is
        # too small for the packed weights.
        for num_elements in (w13_num_elements, w2_num_elements):
            if num_elements % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized weight shape."
                )

        # Fused gate_up_proj (column parallel)
        w13_total_size = w13_num_elements // quant_ratio
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        set_weight_attrs(
            w13_qweight,
            {
                "num_experts": num_experts,
                "input_dim": hidden_size,
                "output_dim": 2 * intermediate_size_per_partition,
                "experts_shape": (
                    num_experts,
                    intermediate_size_per_partition * 2,
                    hidden_size,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        # down_proj (row parallel)
        w2_total_size = w2_num_elements // quant_ratio
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w2_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            w2_qweight,
            {
                "num_experts": num_experts,
                "input_dim": intermediate_size_per_partition,
                "output_dim": hidden_size,
                "experts_shape": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        layer.register_parameter("w2_weight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

    def _create_weights_8bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """8-bit MoE weights are not implemented."""
        raise NotImplementedError(
            "BitsAndBytes 8-bit quantization is not supported for MoE layers."
        )

    def _apply_4bit_dequant(
        self, layer: torch.nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dequantize ``w13_weight`` and ``w2_weight`` to their expert shapes."""
        from bitsandbytes.functional import dequantize_4bit

        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

    # Backward-compatible alias for the previous (misspelled) method name.
    _apply_4bit_dequnt = _apply_4bit_dequant

    def _apply_8bit_dequant(
        self, layer: torch.nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """8-bit MoE dequantization is not implemented."""
        raise NotImplementedError(
            "BitsAndBytes 8-bit dequantization is not supported for MoE layers."
        )