"vllm/vscode:/vscode.git/clone" did not exist on "e52e4da9714962b8db623359992ac3a5853879f7"
awq.py 8.58 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional, Union
5
6
7

import torch

8
from vllm import _custom_ops as ops
9
10
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
11
12
13
14
15
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
16
from vllm.model_executor.layers.quantization import QuantizationMethods
17
from vllm.model_executor.layers.quantization.base_config import (
18
19
20
21
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
22

23
24
# Module-level logger; AWQConfig.get_quant_method uses it to warn when an MoE
# layer falls back from the AWQ Marlin kernels to the MoE WNA16 kernels.
logger = init_logger(__name__)

25
26
27
28
29
30
31
32
33
34
35
36

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978

    Args:
        weight_bits: Bits per quantized weight. Only 4 is supported.
        group_size: Number of input elements that share one scale / zero
            point. -1 means a single group spanning the whole input dimension.
        zero_point: Zero-point flag read from the checkpoint config. It is
            forwarded to the MoE quantization configs.
        modules_to_not_convert: Module-name fragments whose linear layers stay
            unquantized. ``None`` is treated as an empty list.

    Raises:
        ValueError: If ``weight_bits`` is not 4.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        # Number of quantized values packed into one 32-bit integer.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by the AWQ kernels."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (75 = Turing)."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names searched for an AWQ quantization config."""
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an ``AWQConfig`` from a checkpoint's quantization dict.

        Both the AutoAWQ key names (``w_bit``, ``q_group_size``) and the
        generic ones (``bits``, ``group_size``) are accepted.
        ``modules_to_not_convert`` is optional.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Select the quantization method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified name of ``layer``. It is matched against
                ``modules_to_not_convert`` for linear layers.

        Returns:
            For linear layers matched by ``modules_to_not_convert``,
            ``UnquantizedLinearMethod``. For other linear layers,
            ``AWQLinearMethod``. For ``FusedMoE`` layers that the Marlin MoE
            kernel supports, ``AWQMoEMethod``; otherwise the method chosen by
            ``MoeWNA16Config``. For any other layer type, ``None``.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)

        if isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer

            # Both MoE backends are configured from the same AWQ settings.
            # modules_to_not_convert is forwarded on both paths; previously the
            # WNA16 fallback silently dropped it.
            moe_config_dict = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
                "modules_to_not_convert": self.modules_to_not_convert,
            }

            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels."
                )
                return MoeWNA16Config.from_config(moe_config_dict).get_quant_method(
                    layer, prefix
                )

            awq_marlin_config = AWQMarlinConfig.from_config(moe_config_dict)
            return AWQMoEMethod(awq_marlin_config, layer.moe_config)

        return None
130
131


132
def is_layer_skipped_awq(
    prefix: str, modules_to_not_convert: Optional[list[str]]
) -> bool:
    """Return True if the layer named ``prefix`` should stay unquantized.

    A layer is skipped when any entry of ``modules_to_not_convert`` occurs as
    a substring of its fully-qualified ``prefix``.

    Args:
        prefix: Fully-qualified name of the layer.
        modules_to_not_convert: Module-name fragments excluded from AWQ
            quantization. ``None`` or an empty list means nothing is skipped.
    """
    # Callers may hold an un-normalized Optional list; treat None as "none".
    if not modules_to_not_convert:
        return False
    return any(module_name in prefix for module_name in modules_to_not_convert)


136
137
138
139
140
141
142
143
144
145
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Layout of the registered parameters (per partition):

    - ``qweight``: int32, shape ``(in, out // pack_factor)``, packed along
      the output dimension.
    - ``qzeros``: int32, shape ``(num_groups, out // pack_factor)``, packed
      along the output dimension.
    - ``scales``: ``params_dtype``, shape ``(num_groups, out)``.

    Args:
        quant_config: The AWQ quantization config.
    """

    # ``apply`` switches from the fused AWQ GEMM kernel to "dequantize, then
    # dense matmul" once the number of tokens (product of all leading input
    # dims) reaches this value. It is a class attribute so it can be tuned
    # without editing ``apply``.
    FP16_MATMUL_TOKEN_THRESHOLD: int = 256

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``qweight``, ``qzeros`` and ``scales`` on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input size of this partition.
            output_partition_sizes: Output sizes of the sub-layers in this
                partition. Their sum is the output width.
            input_size: Full input size. It is used as the group size when
                ``group_size == -1``.
            output_size: Full output size (not used here).
            params_dtype: dtype of the ``scales`` tensor.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the input partition is not a multiple of the group
                size, or the output partition is not a multiple of
                ``pack_factor``.
        """
        pack_factor = self.quant_config.pack_factor

        # group_size == -1 means one group spanning the whole input dimension.
        group_size = self.quant_config.group_size
        if group_size == -1:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        packed_out = output_size_per_partition // pack_factor
        num_groups = input_size_per_partition // group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain, non-trainable Parameters."""
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute ``x @ W (+ bias)`` with the AWQ-packed weight of ``layer``.

        All leading dims of ``x`` are flattened into one token dim for the
        kernel call and restored on the output. The output's last dim is
        ``qweight.shape[-1] * pack_factor``.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor

        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        num_tokens = x.shape[:-1].numel()
        if num_tokens >= self.FP16_MATMUL_TOKEN_THRESHOLD:
            # Heuristic: for many tokens, dequantize the whole weight once and
            # use a dense matmul instead of the fused kernel. The trailing
            # (0, 0, 0) args are passed as in the original call; their meaning
            # is defined by the kernel (NOTE(review): confirm).
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)