deepspeedfp.py 7.11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional
5
6
7
8

import torch
import torch.nn as nn
import torch.nn.functional as F
9
from packaging import version
10
11

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
12
13
14
15
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
16
17
18
19
20
from vllm.model_executor.utils import set_weight_attrs


class DeepSpeedFPConfig(QuantizationConfig):
    """Config for DeepSpeed FP quantizer. It supports fp6 and fp8.

    Args:
        weight_bits: the target quantization bits, 6 or 8. Defaults to 8.
        group_size: group size for quantization. Defaults to 512.

    Raises:
        ValueError: if ``weight_bits`` is neither 6 nor 8.
    """

    def __init__(
        self,
        weight_bits: int = 8,
        group_size: int = 512,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.valid_types = [torch.bfloat16, torch.float16]

        if self.weight_bits not in (6, 8):
            raise ValueError(
                "Currently, only 6-bit or 8-bit weight quantization are "
                f"supported for DeepSpeed FP quantizaiton, but got "
                f"{self.weight_bits} bits."
            )

    def __repr__(self) -> str:
        """Return a constructor-style description of this config."""
        # Both fields live inside a single pair of parentheses so the output
        # reads like the call that would recreate the object.
        return (
            f"DeepSpeedFPConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "deepspeedfp"

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
        """Build a config from a dict holding ``bits`` and ``group_size``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits=weight_bits, group_size=group_size)

    def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
        """Return a linear method bound to this config."""
        return DeepSpeedFPLinearMethod(self)

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value reported for this method."""
        # NOTE: the right value still needs to be figured out; 60 is what the
        # code has returned so far.
        return 60

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names searched for the quantization config."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["DeepSpeedFPLinearMethod"]:
        """Return the quant method for ``layer``.

        Only ``LinearBase`` layers are handled; any other layer yields
        ``None``. ``prefix`` is accepted but not used here.
        """
        if isinstance(layer, LinearBase):
            return DeepSpeedFPLinearMethod(self)
        return None


class DeepSpeedFPLinearMethod(LinearMethodBase):
    """Linear method backed by the DeepSpeedFP quantizer.

    Args:
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
        self.weight = None
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        weight_loader=None,
        **extra_weight_attrs,
    ):
        """Register a quantized ``weight`` parameter on ``layer``.

        The parameter is a ``DeepSpeedFPParameter`` whose logical shape is
        ``(sum(output_partition_sizes), input_size_per_partition)``. Its
        ``weight_loader`` attribute is replaced by a wrapper that quantizes
        whatever gets loaded.
        """
        # The unpartitioned sizes play no role in building the parameter.
        del input_size, output_size

        rows = sum(output_partition_sizes)
        qweight = DeepSpeedFPParameter(
            torch.Size((rows, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(qweight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", qweight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Without an original loader the incoming tensor is quantized
            # directly into the parameter.
            if weight_loader is None:
                param.ds_quantize_(loaded_weight.cuda())
                return

            # Otherwise: temporarily swap in the dequantized weights, let the
            # original loader write into them, restore the quantized storage,
            # and finally quantize the loader's result into it.
            quantized_storage = param.data
            param.data = param.ds_dequantize()
            weight_loader(param, loaded_weight, *args, **kwargs)
            loaded_full = param.data
            param.data = quantized_storage
            param.ds_quantize_(loaded_full.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dequantize ``layer.weight`` and apply it to ``x`` via ``F.linear``."""
        dequantized = layer.weight.ds_dequantize()
        return F.linear(x, dequantized, bias)


class DeepSpeedFPParameter(nn.Parameter):
    """
    Parameter holding weights quantized to fp8/fp6 with the DeepSpeed
    FP quantizer. The data is kept in quantized (int8) form and can be
    dequantized on-the-fly when the model needs it.
    """

    def __new__(
        cls,
        orig_shape: torch.Size,
        params_dtype: torch.dtype,
        quant_config: DeepSpeedFPConfig,
    ):
        # deepspeed is imported lazily; any import problem, including a too
        # old version, surfaces as one ImportError with install instructions.
        try:
            import deepspeed

            installed = version.parse(deepspeed.__version__)
            if installed < version.parse("0.14.2"):
                raise ImportError(
                    "deepspeed version is wrong. Please install deepspeed>=0.14.2."
                )
            from deepspeed.ops.fp_quantizer import FP_Quantize
        except ImportError as err:
            raise ImportError(
                "Please install deepspeed>=0.14.2 via "
                "`pip install deepspeed>=0.14.2` to use "
                "deepspeedfp quantizer."
            ) from err

        # One int8 row per group of `group_size` elements.
        num_groups = orig_shape.numel() // quant_config.group_size
        # Packed payload plus 4 extra bytes per group (presumably the
        # per-group scale -- confirm against the deepspeed quantizer).
        row_bytes = quant_config.group_size * quant_config.weight_bits // 8 + 4
        storage = torch.empty((num_groups, row_bytes), dtype=torch.int8)

        param = torch.Tensor._make_subclass(cls, storage, storage.requires_grad)
        param.orig_shape = orig_shape
        param.quant_config = quant_config
        param.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
        param.fp_quantizer.orig_shape = orig_shape
        param.fp_quantizer.orig_dtype = params_dtype
        return param

    def ds_quantize_(self, tensor: torch.Tensor):
        """
        Quantize `tensor` and copy the result into this parameter's data.
        `tensor` must be a non-int8 CUDA tensor.
        """
        assert tensor.device.type == "cuda"
        assert tensor.dtype != torch.int8
        bits = self.quant_config.weight_bits
        quantized = self.fp_quantizer.quantize(tensor.data, q_bits=bits)
        return self.data.copy_(quantized)

    def ds_dequantize(self, fp_out=None) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        """
        assert self.data.device.type == "cuda"
        assert self.data.dtype == torch.int8
        bits = self.quant_config.weight_bits
        return self.fp_quantizer.dequantize(self.data, fp_out=fp_out, q_bits=bits)

    def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
        """
        Return a tensor where only the weights at `indices` are dequantized
        (to save HBM -> SRAM bandwidth).
        """
        assert self.data.device.type == "cuda"
        assert self.data.dtype == torch.int8
        bits = self.quant_config.weight_bits
        return self.fp_quantizer.selective_dequantize(
            self.data, indices, fp_out=fp_out, q_bits=bits
        )