bitsandbytes.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Union
5
6

import torch
7
from packaging import version
8

9
10
11
12
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
13
14
15
16
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
17
18
19
20
21
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
    set_weight_attrs,
)
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
27
from vllm.platforms import current_platform
28
from vllm.utils.torch_utils import direct_register_custom_op
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def _check_bitsandbytes_version():
    """Verify that a recent enough bitsandbytes package can be imported.

    The required minimum is 0.49.2 when the current platform reports ROCm and
    0.48.1 otherwise.

    Raises:
        ImportError: If bitsandbytes cannot be imported or its version is
            below the minimum. The triggering error is chained as the cause.
    """
    if current_platform.is_rocm():
        min_version = "0.49.2"
    else:
        min_version = "0.48.1"

    try:
        import bitsandbytes

        installed = version.parse(bitsandbytes.__version__)
        required = version.parse(min_version)
        if installed < required:
            # Raised inside the ``try`` so that it is re-wrapped by the
            # handler below, exactly like a failed import.
            raise ImportError(
                "bitsandbytes version is wrong. Please "
                f"install bitsandbytes>={min_version}."
            )
    except ImportError as err:
        raise ImportError(
            f"Please install bitsandbytes>={min_version} via "
            f"`pip install bitsandbytes>={min_version}` to use "
            "bitsandbytes quantizer."
        ) from err


49
50
51
52
53
54
class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float32",
        bnb_4bit_quant_storage: str = "uint8",
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        llm_int8_skip_modules: list[str] | None = None,
        llm_int8_threshold: float = 6.0,
    ) -> None:
        """Store the quantization options on the instance.

        Args:
            llm_int8_skip_modules: Module names to leave unquantized;
                ``None`` is stored as an empty list.

        The remaining arguments are stored unchanged under attributes of the
        same name.

        Raises:
            ValueError: If ``bnb_4bit_quant_storage`` is not ``"uint8"``.
        """
        super().__init__()
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        self.llm_int8_skip_modules = llm_int8_skip_modules or []
        self.llm_int8_threshold = llm_int8_threshold

        # Only uint8 storage is accepted for packed 4-bit weights.
        if self.bnb_4bit_quant_storage not in ["uint8"]:
            raise ValueError(
                f"Unsupported bnb_4bit_quant_storage: {self.bnb_4bit_quant_storage}"
            )

    def __repr__(self) -> str:
        """Return a summary of the main quantization options."""
        return (
            f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
            f"load_in_4bit={self.load_in_4bit}, "
            f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
            f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
            f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
            f"llm_int8_skip_modules={self.llm_int8_skip_modules})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (70)."""
        return 70

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a dict, using defaults for absent keys.

        Args:
            config: Mapping of option names to values.

        Returns:
            A new instance populated from ``config``.
        """

        def get_safe_value(config, keys, default_value=None):
            # Fall back to the default when the lookup raises ValueError or
            # yields None.
            try:
                value = cls.get_from_keys(config, keys)
                return value if value is not None else default_value
            except ValueError:
                return default_value

        load_in_8bit = get_safe_value(config, ["load_in_8bit"], default_value=False)
        load_in_4bit = get_safe_value(config, ["load_in_4bit"], default_value=True)
        bnb_4bit_compute_dtype = get_safe_value(
            config, ["bnb_4bit_compute_dtype"], default_value="float32"
        )
        bnb_4bit_quant_storage = get_safe_value(
            config, ["bnb_4bit_quant_storage"], default_value="uint8"
        )
        bnb_4bit_quant_type = get_safe_value(
            config, ["bnb_4bit_quant_type"], default_value="fp4"
        )
        bnb_4bit_use_double_quant = get_safe_value(
            config, ["bnb_4bit_use_double_quant"], default_value=False
        )
        llm_int8_enable_fp32_cpu_offload = get_safe_value(
            config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False
        )
        llm_int8_has_fp16_weight = get_safe_value(
            config, ["llm_int8_has_fp16_weight"], default_value=False
        )
        llm_int8_skip_modules = get_safe_value(
            config, ["llm_int8_skip_modules"], default_value=[]
        )
        llm_int8_threshold = get_safe_value(
            config, ["llm_int8_threshold"], default_value=6.0
        )

        return cls(
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            bnb_4bit_quant_storage=bnb_4bit_quant_storage,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
            llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
            llm_int8_skip_modules=llm_int8_skip_modules,
            llm_int8_threshold=llm_int8_threshold,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None:
        """Select the quantization method for ``layer``.

        Args:
            layer: The module being built.
            prefix: Dotted name of the module, matched against
                ``llm_int8_skip_modules``.

        Returns:
            An unquantized linear method for skipped linear layers, the
            BitsAndBytes linear method for other linear layers, the
            BitsAndBytes MoE method for ``FusedMoE`` layers, and ``None`` for
            any other layer type.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                return UnquantizedLinearMethod()
            return BitsAndBytesLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BitsAndBytesMoEMethod(self, layer.moe_config)
        return None


172
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]) -> bool:
    """Return whether the layer named ``prefix`` should stay unquantized.

    A layer is skipped when an entry of ``llm_int8_skip_modules`` either

    * equals one dot-separated component of ``prefix`` exactly
      (``"gate"`` matches ``"model.mlp.gate"``), or
    * equals a leading dotted path of ``prefix``
      (``"model.layers.0"`` matches ``"model.layers.0.mlp"``).

    Args:
        prefix: Dotted module name of the layer.
        llm_int8_skip_modules: Module names or dotted paths to skip.

    Returns:
        ``True`` if the layer matches any entry, ``False`` otherwise.
    """
    components = prefix.split(".")
    skip_modules = set(llm_int8_skip_modules)

    # Exact match against a single component of the dotted name.
    if not skip_modules.isdisjoint(components):
        return True

    # Match against every leading dotted path: "a", "a.b", "a.b.c", ...
    leading_paths = (
        ".".join(components[: i + 1]) for i in range(len(components))
    )
    return not skip_modules.isdisjoint(leading_paths)
187
188


189
190
191
192
193
194
195
def calculate_quant_ratio(dtype):
    """Return how many uint8 elements span one element of ``dtype``.

    The result is the bit width of ``dtype`` integer-divided by the bit width
    of ``torch.uint8``; floating-point and integer dtypes are both supported.
    """
    if dtype.is_floating_point:
        dtype_info = torch.finfo(dtype)
    else:
        dtype_info = torch.iinfo(dtype)
    uint8_bits = torch.iinfo(torch.uint8).bits
    return dtype_info.bits // uint8_bits


196
197
198
199
200
201
202
203
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.
    """

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check the installed bitsandbytes version and keep the config."""
        _check_bitsandbytes_version()
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the quantized ``weight`` parameter and register it.

        An int8 ``Int8Params`` tensor is created when the config has
        ``load_in_8bit`` set; otherwise a packed uint8 parameter for 4-bit
        weights is created.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output features of each fused shard.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype used to derive the 4-bit packing ratio.
            **extra_weight_attrs: Extra attributes attached to the weight.

        Raises:
            ValueError: In the 4-bit path, if the element count is not a
                multiple of the packing ratio.
        """
        from bitsandbytes.nn import Int8Params

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=torch.int8,
                ),
                has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
                requires_grad=False,
            )
            set_weight_attrs(
                qweight,
                {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": 1,
                    "use_bitsandbytes_8bit": True,
                    # Forward-call counter read by _apply_8bit_weight.
                    "generation": 0,
                },
            )
            return qweight

        def create_qweight_for_4bit():
            quant_ratio = calculate_quant_ratio(params_dtype)

            total_size = input_size_per_partition * sum(output_partition_sizes)
            if total_size % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized weight shape."
                )

            # Packed weights are stored as a single uint8 column.
            qweight = torch.nn.Parameter(
                torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
                requires_grad=False,
            )
            set_weight_attrs(
                qweight,
                {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": quant_ratio,
                    "use_bitsandbytes_4bit": True,
                },
            )
            return qweight

        if self.quant_config.load_in_8bit:
            qweight = create_qweight_for_8bit()
        else:
            qweight = create_qweight_for_4bit()
        # Enable parameters to have the same name as in the BNB
        # checkpoint format.
        layer.register_parameter("weight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer, using the 8-bit or 4-bit path.

        Args:
            layer: Module whose ``weight`` holds the quantized data.
            x: Input activations.
            bias: Optional bias added to the result.

        Returns:
            The layer output in the dtype of ``x``.
        """
        if self.quant_config.load_in_8bit:
            return self._apply_8bit_weight(layer, x, bias)
        else:
            return self._apply_4bit_weight(layer, x, bias)

    def _apply_8bit_weight(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear output with bitsandbytes int8 matmul.

        Each shard of the fused weight is multiplied separately and written
        into its column range of the output. The weight's ``generation``
        counter is incremented on every call.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import MatmulLtState, matmul

        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten the leading dimensions; restored after the matmul.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        # NOTE(review): these attributes are not set in this class;
        # presumably the weight loader attaches them - confirm.
        offsets = qweight.bnb_shard_offsets
        quant_states = qweight.bnb_quant_state
        matmul_states = qweight.matmul_state
        generation = qweight.generation

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(
            quant_state.shape[0] for quant_state in quant_states.values()
        )
        out = torch.empty(out_dim_0, out_dim_1, dtype=torch.float16, device=x.device)

        # The matmul input is the same for every shard, so build it once.
        new_x = bf_x.unsqueeze(0)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]

            # in profile_run or the first generation of inference,
            # create new matmul_states
            if generation in (0, 1):
                matmul_states[i] = MatmulLtState()
                matmul_states[i].CB = qweight[offsets[i] : offsets[i + 1]]
                matmul_states[i].SCB = quant_states[i].to(x.device)
                matmul_states[i].threshold = self.quant_config.llm_int8_threshold
                matmul_states[
                    i
                ].has_fp16_weights = self.quant_config.llm_int8_has_fp16_weight
                matmul_states[i].is_training = False
                if (
                    matmul_states[i].threshold > 0.0
                    and not matmul_states[i].has_fp16_weights
                ):
                    matmul_states[i].use_pool = True

            out[:, current_index : current_index + output_size] = matmul(
                new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i]
            )

            current_index += output_size

        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        qweight.generation += 1

        return out

    def _apply_4bit_weight(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear output through the ``apply_bnb_4bit`` op.

        The input is cast to bfloat16, the op fills a preallocated bfloat16
        output, and the result is cast back to the dtype of ``x``.
        """
        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten the leading dimensions; restored after the matmul.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        # NOTE(review): these attributes are not set in this class;
        # presumably the weight loader attaches them - confirm.
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(
            quant_state.shape[0] for quant_state in quant_states.values()
        )
        out = torch.empty(out_dim_0, out_dim_1, dtype=torch.bfloat16, device=x.device)
        apply_bnb_4bit(bf_x, qweight, offsets, out)
        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        return out
388
389
390
391
392
393
394
395
396
397


def _apply_bnb_4bit(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Multiply ``x`` by each 4-bit weight shard and write into ``out``.

    Args:
        x: Input activations.
        weight: Packed weight carrying a ``bnb_quant_state`` attribute that is
            indexed by shard number.
        offsets: Shard boundaries; shard ``i`` is
            ``weight[offsets[i]:offsets[i + 1]]``.
        out: Preallocated output, filled in place one column range per shard.
    """
    # only load the bitsandbytes module when needed
    from bitsandbytes import matmul_4bit

    quant_states = weight.bnb_quant_state
    current_index = 0
    for i in range(len(quant_states)):
        output_size = quant_states[i].shape[0]
        # It is more efficient to use out kwarg like
        # matmul_4bit(..., out = ...).  Infeasible now due to the bug
        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
        # Need to change  after the bug is fixed.
        out[:, current_index : current_index + output_size] = matmul_4bit(
            x, weight[offsets[i] : offsets[i + 1]].t(), quant_states[i]
        )
        current_index += output_size


def _apply_bnb_4bit_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    return


# Register the 4-bit matmul as a custom op and expose it under the name used
# by BitsAndBytesLinearMethod. ``out`` is declared as mutated because the op
# writes its result into that tensor. Registration errors propagate to the
# importer unchanged; the previous ``except AttributeError: raise`` wrapper
# only re-raised and has been dropped.
direct_register_custom_op(
    op_name="apply_bnb_4bit",
    op_func=_apply_bnb_4bit,
    mutates_args=["out"],
    fake_impl=_apply_bnb_4bit_fake,
    dispatch_key=current_platform.dispatch_key,
)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
434
435
436
437
438
439
440
441
442


class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    """MoE method for BitsAndBytes.

    Args:
       quant_config: The BitsAndBytes quantization config.
    """

    def __init__(
        self,
        quant_config: BitsAndBytesConfig,
        moe: FusedMoEConfig,
    ):
        """Check the installed bitsandbytes version and keep the config."""
        super().__init__(moe)
        _check_bitsandbytes_version()
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights via the 8-bit or 4-bit helper.

        Raises:
            NotImplementedError: If the config has ``load_in_8bit`` set.
        """
        if self.quant_config.load_in_8bit:
            call_fun = self._create_weights_8bit
        else:
            call_fun = self._create_weights_4bit
        call_fun(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return ``None``; no fused-MoE quant config is built here."""
        return None

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dequantize the expert weights and run ``fused_experts``.

        Args:
            layer: MoE layer holding ``w13_weight`` and ``w2_weight``.
            x: Hidden states passed to the experts.
            topk_weights: Routing weights forwarded to ``fused_experts``.
            topk_ids: Expert ids forwarded to ``fused_experts``.
            shared_experts_input: Not used by this method.

        Returns:
            The result of ``fused_experts``.

        Raises:
            NotImplementedError: If the config has ``load_in_8bit`` set.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        # TODO(bnell): Do these need to be called on the hot path?
        if self.quant_config.load_in_8bit:
            w13, w2 = self._apply_8bit_dequant(layer)
        else:
            w13, w2 = self._apply_4bit_dequant(layer)
        return fused_experts(
            hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    def _create_weights_4bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed uint8 ``w13_weight`` and ``w2_weight`` parameters.

        Each parameter has shape ``(num_experts, packed_size, 1)`` and carries
        an ``experts_shape`` attribute that is used to reshape the weights
        after dequantization.
        """
        quant_ratio = calculate_quant_ratio(params_dtype)
        # Fused gate_up_proj (column parallel)
        w13_total_size = (
            hidden_size * 2 * intermediate_size_per_partition
        ) // quant_ratio
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        set_weight_attrs(
            w13_qweight,
            {
                "num_experts": num_experts,
                "input_dim": hidden_size,
                "output_dim": 2 * intermediate_size_per_partition,
                "experts_shape": (
                    num_experts,
                    intermediate_size_per_partition * 2,
                    hidden_size,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        # down_proj (row parallel)
        w2_total_size = (hidden_size * intermediate_size_per_partition) // quant_ratio
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w2_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            w2_qweight,
            {
                "num_experts": num_experts,
                "input_dim": intermediate_size_per_partition,
                "output_dim": hidden_size,
                "experts_shape": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        layer.register_parameter("w2_weight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

    def _create_weights_8bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """8-bit MoE weight creation; not implemented."""
        raise NotImplementedError

    def _apply_4bit_dequant(
        self, layer: torch.nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dequantize both expert weights and reshape them per expert.

        Returns:
            ``(w13, w2)``, each reshaped to the ``experts_shape`` stored on
            the corresponding packed parameter.
        """
        from bitsandbytes.functional import dequantize_4bit

        # NOTE(review): ``bnb_quant_state`` is not set in this class;
        # presumably the weight loader attaches it - confirm.
        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

    # Backward-compatible alias for the previously misspelled method name.
    _apply_4bit_dequnt = _apply_4bit_dequant

    def _apply_8bit_dequant(
        self, layer: torch.nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """8-bit MoE dequantization; not implemented."""
        raise NotImplementedError