awq.py 8.74 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional, Union
5
6
7

import torch

8
from vllm import _custom_ops as ops
9
10
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
11
12
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
13
from vllm.model_executor.layers.quantization import QuantizationMethods
14
from vllm.model_executor.layers.quantization.base_config import (
15
    QuantizationConfig, QuantizeMethodBase)
16
17
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
18

19
20
logger = init_logger(__name__)

21
22
23
24
25
26
27
28
29
30
31
32

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978

    Args:
        weight_bits: Bit width of the quantized weights. Only 4 is accepted.
        group_size: Quantization group size. ``-1`` is replaced by the full
            input size when the linear weights are created.
        zero_point: Value of the ``zero_point`` config key; stored here and
            forwarded to the MoE configs built in ``get_quant_method``.
        modules_to_not_convert: Name fragments of layers that must stay
            unquantized. ``None`` is normalized to an empty list.

    Raises:
        ValueError: If ``weight_bits`` is not 4.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    def get_name(self) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts (fp16 only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability, encoded as an int."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names that may hold the quantization config."""
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an ``AWQConfig`` from a parsed quantization config dict.

        Each setting is looked up under several alternative key names
        (e.g. ``w_bit`` or ``bits``). ``modules_to_not_convert`` is optional
        and defaults to ``None``.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being set up.
            prefix: Name of the layer; compared against
                ``modules_to_not_convert`` and used in the fallback warning.

        Returns:
            * For ``LinearBase`` layers: ``UnquantizedLinearMethod`` when the
              prefix matches a skipped module, otherwise ``AWQLinearMethod``.
            * For ``FusedMoE`` layers: ``AWQMoEMethod`` when the Marlin MoE
              check passes, otherwise whatever ``MoeWNA16Config`` selects.
            * ``None`` for every other layer type.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                # Note: the fallback config does not carry
                # ``modules_to_not_convert``.
                config = {
                    "quant_method": "awq",
                    "bits": self.weight_bits,
                    "group_size": self.group_size,
                    "zero_point": self.zero_point,
                    "lm_head": False,
                }
                return MoeWNA16Config.from_config(config).get_quant_method(
                    layer, prefix)
            marlin_compatible_config_dict = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
                "modules_to_not_convert": self.modules_to_not_convert,
            }
            awq_marlin_config = AWQMarlinConfig.from_config(
                marlin_compatible_config_dict)
            return AWQMoEMethod(awq_marlin_config)
        return None
118
119


120
def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: list[str]) -> bool:
    """Return whether the layer named ``prefix`` should stay unquantized.

    Matching is by substring: the layer is skipped as soon as any entry of
    ``modules_to_not_convert`` occurs anywhere inside ``prefix``. An empty
    list therefore never skips a layer.
    """
    return any(module_name in prefix for module_name in modules_to_not_convert)


124
125
126
127
128
129
130
131
132
133
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register the quantized parameters on ``layer``.

        Three parameters are registered:

        * ``qweight``: int32, shape
          ``(input_size_per_partition, output_size_per_partition // pack_factor)``
        * ``qzeros``: int32, shape
          ``(num_groups, output_size_per_partition // pack_factor)``
        * ``scales``: ``params_dtype``, shape
          ``(num_groups, output_size_per_partition)``

        where ``output_size_per_partition`` is the sum of
        ``output_partition_sizes`` and ``num_groups`` is
        ``input_size_per_partition // group_size``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output sizes of the logical sub-weights
                of this partition; only their sum is used here.
            input_size: Full input size; used as the group size when the
                configured ``group_size`` is ``-1``.
            output_size: Not used by this method.
            params_dtype: Dtype of the ``scales`` tensor.
            **extra_weight_attrs: Only ``"weight_loader"`` is read; it is
                attached to every created parameter.

        Raises:
            ValueError: If the partitioned input size is not a multiple of
                the group size, or the partitioned output size is not a
                multiple of the pack factor.
        """
        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        # Quantized weights, packed along the output dimension.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        num_groups = input_size_per_partition // group_size

        # One row of packed zero points per quantization group.
        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # One row of (unpacked) scales per quantization group.
        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain, frozen ``torch.nn.Parameter``s.

        The data of ``qweight``, ``qzeros`` and ``scales`` is kept; only the
        parameter wrapper is replaced, with ``requires_grad=False``.
        """
        layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                           requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                          requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data,
                                          requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized linear projection of ``x``.

        Args:
            layer: Module holding ``qweight``, ``scales`` and ``qzeros``.
            x: Input activations; flattened to 2D over the last dimension
                before the matrix product.
            bias: Optional bias, added in place to the 2D result.

        Returns:
            A tensor of shape
            ``x.shape[:-1] + (qweight.shape[-1] * pack_factor,)``.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # num_tokens >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            # Many tokens: dequantize the weight once and use a dense matmul.
            # NOTE(review): the three trailing zeros are kernel arguments
            # whose meaning is defined by ops.awq_dequantize - confirm there.
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            # Few tokens: use the fused AWQ GEMM kernel directly.
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)