deepspeedfp.py 6.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class DeepSpeedFPConfig(QuantizationConfig):
    """Config for the DeepSpeed FP quantizer. It supports fp6 and fp8.

    Args:
        weight_bits: the target quantization bits, 6 or 8. Defaults to 8.
        group_size: group size for quantization. Defaults to 512.

    Raises:
        ValueError: if ``weight_bits`` is neither 6 nor 8.
    """

    def __init__(
        self,
        weight_bits: int = 8,
        group_size: int = 512,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.valid_types = [torch.bfloat16, torch.float16]

        if self.weight_bits not in (6, 8):
            raise ValueError(
                "Currently, only 6-bit or 8-bit weight quantization are "
                "supported for DeepSpeed FP quantizaiton, but got "
                f"{self.weight_bits} bits.")

    def __repr__(self) -> str:
        # Both fields belong inside the parentheses; the closing parenthesis
        # goes after group_size, not after weight_bits.
        return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "DeepSpeedFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
        """Build a config from a dict holding "bits" and "group_size"."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits=weight_bits, group_size=group_size)

    def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
        """Return a linear method bound to this config."""
        return DeepSpeedFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names; always an empty list here."""
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this quantizer."""
        return [torch.half, torch.bfloat16]

    @classmethod
    # TODO: the right minimum capability still needs to be figured out.
    def get_min_capability(cls) -> int:
        """Return the minimum capability value reported for this method."""
        return 60

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names that may hold the quantization config."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
        """Return the quantization method for ``layer``.

        Only ``LinearBase`` layers are handled; ``None`` is returned for any
        other layer type. ``prefix`` is accepted but not used.
        """
        if isinstance(layer, LinearBase):
            return DeepSpeedFPLinearMethod(self)
        return None


class DeepSpeedFPLinearMethod(LinearMethodBase):
    """Linear layer method backed by DeepSpeedFP-quantized weights.

    Args:
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
        self.weight = None
        self.quant_config = quant_config

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        """Create the quantized "weight" parameter and register it on `layer`.

        The parameter is built for a logical shape of
        (sum(output_partition_sizes), input_size_per_partition). Its
        "weight_loader" attribute is replaced by a wrapper that quantizes
        whatever gets loaded.
        """
        # The unpartitioned sizes are not used by this method.
        del input_size, output_size

        logical_shape = torch.Size(
            (sum(output_partition_sizes), input_size_per_partition))
        weight = DeepSpeedFPParameter(
            logical_shape,
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # When an original loader exists, let it write into a dequantized
            # view of the parameter, then take that result as the tensor to
            # quantize. Without one, the incoming tensor is quantized as-is.
            if weight_loader is not None:
                quantized_storage = param.data
                param.data = param.ds_dequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
                loaded_weight = param.data
                # Put the quantized buffer back before filling it.
                param.data = quantized_storage
            param.ds_quantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Dequantize the layer weight and apply it to `x` with `F.linear`."""
        dequantized = layer.weight.ds_dequantize()
        return F.linear(x, dequantized, bias)


class DeepSpeedFPParameter(nn.Parameter):
    """
    DeepSpeedFP quantized parameter class that implements fp8/fp6
    quantization via deepspeed. Weights are stored in quantized (int8) form
    and can be dequantized on-the-fly when needed by the model. The
    quantize/dequantize methods assert that the tensors involved are on a
    CUDA device.
    """

    @staticmethod
    def _version_tuple(version: str) -> tuple:
        """Return the leading numeric release of `version` as a tuple of ints.

        Anything after the dotted numeric prefix (e.g. "rc1", ".dev0",
        "+local") is ignored. An empty tuple is returned when the string
        does not start with a digit.
        """
        import re

        matched = re.match(r"\d+(?:\.\d+)*", version)
        if matched is None:
            return ()
        return tuple(int(part) for part in matched.group().split("."))

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: "DeepSpeedFPConfig"):
        """Allocate the int8 storage and attach a deepspeed FP quantizer.

        Args:
            orig_shape: shape of the unquantized weight.
            params_dtype: dtype recorded on the quantizer as the original
                dtype of the weight.
            quant_config: supplies `group_size` and `weight_bits`.

        Raises:
            ImportError: if deepspeed is missing or older than 0.14.2.
        """
        try:
            import deepspeed

            # Compare numerically: a plain string comparison would accept
            # "0.9.0" and reject "0.100.0". An unparseable version string
            # yields an empty tuple and skips the check.
            installed = cls._version_tuple(deepspeed.__version__)
            if installed and installed < (0, 14, 2):
                raise ImportError("deepspeed version is wrong. Please "
                                  "install deepspeed>=0.14.2.")
            from deepspeed.ops.fp_quantizer import FP_Quantize
        except ImportError as err:
            raise ImportError("Please install deepspeed>=0.14.2 via "
                              "`pip install deepspeed>=0.14.2` to use "
                              "deepspeedfp quantizer.") from err
        # One row per quantization group. Each row holds the packed bits of
        # the group plus 4 extra bytes (presumably the per-group scale --
        # confirm against FP_Quantize).
        data = torch.empty(
            (
                orig_shape.numel() // quant_config.group_size,
                quant_config.group_size * quant_config.weight_bits // 8 + 4,
            ),
            dtype=torch.int8,
        )
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.orig_shape = orig_shape
        self.quant_config = quant_config
        self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
        self.fp_quantizer.orig_shape = orig_shape
        self.fp_quantizer.orig_dtype = params_dtype
        return self

    def ds_quantize_(self, tensor: torch.Tensor):
        """
        Quantize `tensor` and copy the result into this parameter in place.
        `tensor` must be on a CUDA device and must not already be int8.
        """
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        return self.data.copy_(
            self.fp_quantizer.quantize(
                tensor.data,
                q_bits=self.quant_config.weight_bits,
            ))

    def ds_dequantize(self, fp_out=None) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        `fp_out` is forwarded to the quantizer's `dequantize` call.
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.dequantize(
            self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)

    def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
        """
        Return a tensor where only the weights at `indices` are dequantized
        (to save HBM -> SRAM bandwidth).
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.selective_dequantize(
            self.data,
            indices,
            fp_out=fp_out,
            q_bits=self.quant_config.weight_bits)