awq.py 6.29 KB
Newer Older
1
2
3
4
from typing import Any, Dict, List, Optional

import torch

5
from vllm import _custom_ops as ops
6
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
7
8
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
9
10
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45


class AWQConfig(QuantizationConfig):
    """Quantization settings for AWQ checkpoints.

    Stores the weight bit-width, the quantization group size and the
    zero-point flag, together with the derived ``pack_factor``
    (``32 // weight_bits``).

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        """Record the settings and derive ``pack_factor``.

        Args:
            weight_bits: Bit-width of the quantized weights.
            group_size: Quantization group size.
            zero_point: Zero-point flag taken from the checkpoint config.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {weight_bits} bits.")

        # How many weight_bits-wide values fit into 32 bits.
        self.pack_factor = 32 // weight_bits

    def __repr__(self) -> str:
        """Return a constructor-like description of this config."""
        described = ", ".join((
            f"weight_bits={self.weight_bits}",
            f"group_size={self.group_size}",
            f"zero_point={self.zero_point}",
        ))
        return "AWQConfig(" + described + ")"

    def get_name(self) -> str:
        """Return the identifier of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the supported activation dtypes (``torch.half`` only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability, encoded as an int."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names under which an AWQ config may be stored."""
        return [
            # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quant_config.json",
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        """Build a config from a parsed quantization-config dictionary.

        Each setting is looked up through ``get_from_keys`` using the
        alternative key spellings listed below.
        """
        return cls(
            cls.get_from_keys(config, ["w_bit", "bits"]),
            cls.get_from_keys(config, ["q_group_size", "group_size"]),
            cls.get_from_keys(config, ["zero_point"]),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQLinearMethod"]:
        """Return an ``AWQLinearMethod`` for ``LinearBase`` layers.

        Any other layer type yields ``None``. ``prefix`` is not used here.
        """
        if not isinstance(layer, LinearBase):
            return None
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation names reported as scaled for AWQ."""
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

75
76
77
78
79
80
81
82
83
84
85

class AWQLinearMethod(LinearMethodBase):
    """Linear-layer implementation backed by AWQ-quantized weights.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register ``qweight``, ``qzeros`` and ``scales``.

        Args:
            layer: Module on which the three parameters are registered.
            input_size_per_partition: Input size of this partition; must
                be a multiple of the configured group size.
            output_partition_sizes: Output sizes of this partition; their
                sum must be a multiple of the configured pack factor.
            input_size: Not used by this method.
            output_size: Not used by this method.
            params_dtype: Dtype of the ``scales`` tensor.
            **extra_weight_attrs: Only ``"weight_loader"`` is read; it is
                forwarded to every created parameter.

        Raises:
            ValueError: If either size is not aligned as described above.
        """
        group_size = self.quant_config.group_size
        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        pack_factor = self.quant_config.pack_factor
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        packed_cols = output_size_per_partition // pack_factor
        num_groups = input_size_per_partition // group_size

        def _packed_int32(num_rows: int) -> PackedvLLMParameter:
            # qweight and qzeros differ only in their row count.
            return PackedvLLMParameter(
                data=torch.empty(num_rows, packed_cols, dtype=torch.int32),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                weight_loader=weight_loader)

        qweight = _packed_int32(input_size_per_partition)
        qzeros = _packed_int32(num_groups)
        scales = GroupQuantScaleParameter(
            data=torch.empty(num_groups,
                             output_size_per_partition,
                             dtype=params_dtype),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        for param_name, param in (("qweight", qweight), ("qzeros", qzeros),
                                  ("scales", scales)):
            layer.register_parameter(param_name, param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain, non-trainable parameters.

        ``qweight``, ``qzeros`` and ``scales`` are each replaced by a
        ``torch.nn.Parameter`` built from their ``.data`` with
        ``requires_grad=False``.
        """
        for param_name in ("qweight", "qzeros", "scales"):
            loaded = getattr(layer, param_name)
            setattr(layer, param_name,
                    torch.nn.Parameter(loaded.data, requires_grad=False))

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer's quantized weight.

        Args:
            layer: Module holding ``qweight``, ``scales`` and ``qzeros``.
            x: Input tensor; every dimension but the last is flattened
                before the multiplication.
            bias: Optional bias, added in place to the result.

        Returns:
            A tensor with the leading dimensions of ``x`` and a last
            dimension of ``qweight.shape[-1] * pack_factor``.
        """
        qweight, scales, qzeros = layer.qweight, layer.scales, layer.qzeros
        pack_factor = self.quant_config.pack_factor

        leading_shape = x.shape[:-1]
        out_shape = leading_shape + (qweight.shape[-1] * pack_factor, )
        flat_x = x.reshape(-1, x.shape[-1])

        # Below 256 tokens the AWQ GEMM kernel is used; at or above that
        # threshold the weight is dequantized and a regular matmul is run.
        num_tokens = leading_shape.numel()
        if num_tokens < 256:
            out = ops.awq_gemm(flat_x, qweight, scales, qzeros, pack_factor)
        else:
            dequantized = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(flat_x, dequantized)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)