awq.py 11 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
from typing import Any, Dict, List, Optional

import torch
zhuwenwen's avatar
zhuwenwen committed
6
import os
7
import torch.nn.functional as F
8
9
10
11
import vllm.envs as envs
import json
import math
from vllm.platforms import current_platform
12
from vllm import _custom_ops as ops
13
14
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
17
18
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
from vllm.logger import init_logger
# Module-level logger created through vLLM's init_logger helper.
logger = init_logger(__name__)
# Process-wide cache of Triton AWQ kernel configs. get_triton_cache() fills it
# with keys of the form f"{sub_key}_{key}" and getspec_config() reads it back
# using f"{M}_{N}_{K}".
triton_configs_dict={}

def get_triton_cache(file_path):
    """Load tuned Triton AWQ kernel configs from a JSON file into the cache.

    The file is read as a two-level mapping ``{key: {sub_key: config}}``.
    Every ``config`` is flattened into the module-level
    ``triton_configs_dict`` under ``f"{sub_key}_{key}"`` (the ``M_N_K`` form
    that ``getspec_config`` looks up), with each field coerced to ``int``.

    Args:
        file_path: Path of the JSON file holding the saved configs.

    Returns:
        None. The function only populates ``triton_configs_dict``.

    Raises:
        KeyError: If a config entry lacks one of the mandatory fields.
    """
    if not os.path.exists(file_path):
        # Nothing to load. Previously execution fell through to the loop
        # below with `cachedata` unbound and crashed with UnboundLocalError.
        return

    with open(file_path, 'r') as file:
        cachedata = json.load(file)

    # Fields every config entry must provide.
    required_fields = (
        'SPLIT_K',
        'BLOCK_SIZE_M',
        'BLOCK_SIZE_N',
        'BLOCK_SIZE_K',
        'GROUP_SIZE_M',
        'num_stages',
        'num_warps',
    )

    # Flatten every cached entry into key -> config form: [M_N_K] -> [config].
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {
                name: int(sub_value[name])
                for name in required_fields
            }
            # Optional field, only present in some config files.
            if 'num_ldmatrixes' in sub_value:
                configs_value["num_ldmatrixes"] = int(
                    sub_value['num_ldmatrixes'])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k,n):
    """Make sure the tuned Triton configs for one weight shape are cached.

    Does nothing when ``triton_configs_dict`` already holds the ``1_{n}_{k}``
    entry. Otherwise it looks for ``AWQ_{n}_{k}_{device}.json`` inside the
    ``configs/awq/`` directory next to this module and, when that file exists,
    loads it through ``get_triton_cache``.

    Args:
        k: Value used as the trailing component of the cache key / file name.
        n: Value used as the middle component of the cache key / file name.
    """
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_dir = f"{module_dir}/configs/awq/"
    # Spaces in the device name are replaced so it can be part of a file name.
    device_tag = current_platform.get_device_name().replace(" ", "_")
    candidate = os.path.join(config_dir, f"AWQ_{n}_{k}_{device_tag}.json")

    # Only load the config when it is an existing regular .json file.
    if os.path.isfile(candidate) and candidate.endswith(".json"):
        get_triton_cache(candidate)


def getspec_config(M,N,K):
    """Return the cached Triton config stored under ``"M_N_K"``, or None.

    Args:
        M: First component of the cache key.
        N: Second component of the cache key.
        K: Third component of the cache key.

    Returns:
        The config dict from ``triton_configs_dict`` when present, otherwise
        ``None``.
    """
    return triton_configs_dict.get(f"{M}_{N}_{K}")
70
71


72
73
74
75
76
77
78
79
80
81
class AWQShareWorkSpace:
    """Singleton that holds the shared AWQ workspace handles.

    The first construction queries ``ops.GetAWQShareWorkspaceSize()`` and
    ``ops.GetAWQShareWorkspace()`` and stores the results on the instance;
    every later construction returns that same instance.
    """

    # The one shared instance, created lazily on first construction.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating and initializing it once.

        ``*args`` / ``**kwargs`` are accepted for call compatibility but are
        not used.
        """
        if cls._instance is None:
            # object.__new__ rejects extra arguments when __new__ is
            # overridden, so they must not be forwarded.
            instance = super().__new__(cls)
            instance._initialize()
            # Publish only after initialization succeeded, so a failure
            # cannot leave a half-built singleton cached.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Fetch the workspace size and workspace from the custom ops."""
        # Attribute names (including their spelling) are read by callers.
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
84
85
86
87
88
89
90
91
92
93
94
95
96


class AWQConfig(QuantizationConfig):
    """Quantization settings for AWQ checkpoints.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[List[str]] = None,
    ) -> None:
        """Store the AWQ settings and derive the packing factor.

        Args:
            weight_bits: Bit width of the quantized weights; must be 4.
            group_size: Quantization group size.
            zero_point: Zero-point flag taken from the checkpoint config.
            modules_to_not_convert: Name fragments of layers that stay
                unquantized; ``None`` or empty means no layer is skipped.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        if modules_to_not_convert:
            self.modules_to_not_convert = modules_to_not_convert
        else:
            self.modules_to_not_convert = []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        """Return a string listing the config fields."""
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    def get_name(self) -> str:
        """Return the name of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this method supports."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum required GPU compute capability."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names that may hold the quantization config."""
        return [
            # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quant_config.json",
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        """Build an ``AWQConfig`` from a parsed quantization config dict.

        Each required setting is looked up under its alternative key names;
        ``modules_to_not_convert`` is optional and defaults to ``None``.
        """
        bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group = cls.get_from_keys(config, ["q_group_size", "group_size"])
        has_zero_point = cls.get_from_keys(config, ["zero_point"])
        skipped = cls.get_from_keys_or(config, ["modules_to_not_convert"],
                                       None)
        return cls(
            weight_bits=bits,
            group_size=group,
            zero_point=has_zero_point,
            modules_to_not_convert=skipped,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
        """Choose the linear method for ``layer``.

        Returns:
            ``None`` for layers that are not ``LinearBase``; an
            ``UnquantizedLinearMethod`` when ``prefix`` matches an entry of
            ``modules_to_not_convert``; otherwise an ``AWQLinearMethod``.
        """
        if not isinstance(layer, LinearBase):
            return None
        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        return AWQLinearMethod(self)
152
153


154
155
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
    """Tell whether the layer named ``prefix`` should stay unquantized.

    Returns:
        True when any entry of ``modules_to_not_convert`` occurs as a
        substring of ``prefix``, False otherwise (including an empty list).
    """
    for module_name in modules_to_not_convert:
        if module_name in prefix:
            return True
    return False
156

157
158
159
160
161
162
163
164
165
166

class AWQLinearMethod(LinearMethodBase):
    """Linear layer implementation backed by AWQ-quantized weights.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        """Keep the config, grab the shared workspace and read ``AWQ_PAD``."""
        self.quant_config = quant_config
        self.awqsingleton = AWQShareWorkSpace()
        # Padding is opt-in through the AWQ_PAD=1 environment variable.
        self.use_awq_pad = os.getenv('AWQ_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register the AWQ parameters on ``layer``.

        Registers ``qweight``, ``qzeros``, ``scales`` and
        ``zeros_and_scales``; when ``envs.VLLM_USE_TRITON_AWQ`` is set it also
        loads the tuned Triton configs for this weight shape.

        Raises:
            ValueError: If the per-partition input size is not a multiple of
                the group size, or the per-partition output size is not a
                multiple of the pack factor.
        """
        group_size = self.quant_config.group_size
        pack_factor = self.quant_config.pack_factor

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        num_groups = input_size_per_partition // group_size
        packed_output_size = output_size_per_partition // pack_factor

        # Packed int32 weights: (input, output // pack_factor).
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        # Packed int32 zero points: one row per quantization group.
        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        # Per-group scales in the model's parameter dtype.
        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        # int32 buffer with the same shape as `scales`; it is the tensor
        # handed to ops.awq_gemm in apply().
        zeros_and_scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        for param_name, param in (
            ("qweight", qweight),
            ("qzeros", qzeros),
            ("scales", scales),
            ("zeros_and_scales", zeros_and_scales),
        ):
            layer.register_parameter(param_name, param)

        # Load the Triton config for this weight shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition,
                              output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap each AWQ tensor as a plain non-trainable Parameter."""
        for param_name in ("qweight", "qzeros", "scales",
                           "zeros_and_scales"):
            loaded = getattr(layer, param_name)
            setattr(layer, param_name,
                    torch.nn.Parameter(loaded.data, requires_grad=False))

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized matmul for ``x`` and add ``bias`` if given.

        Uses ``awq_gemm_triton`` when ``envs.VLLM_USE_TRITON_AWQ`` is set and
        ``ops.awq_gemm`` otherwise. The result is reshaped so its leading
        dimensions match those of ``x``.
        """
        qweight = layer.qweight
        zeros_and_scales = layer.zeros_and_scales
        qzeros = layer.qzeros
        scales = layer.scales
        pack_factor = self.quant_config.pack_factor

        # Collapse all leading dimensions of x into one row dimension.
        reshaped_x = x.reshape(-1, x.shape[-1])
        m, k = reshaped_x.shape
        # NOTE(review): create_weights allocates qweight as
        # (input_size, output_size // pack_factor), so shape[0] looks like the
        # input dimension unless the tensor is re-laid-out elsewhere; confirm
        # that `n` (and the non-Triton output width below) are as intended.
        n = qweight.shape[0]

        # Padding is only requested when enabled and k is a multiple of 4096.
        padding_group = 2 if (self.use_awq_pad and k % 4096 == 0) else 0

        if envs.VLLM_USE_TRITON_AWQ:
            # Above 16 rows, round m up to a power of two before the config
            # lookup.
            if m > 16:
                m = 2 ** math.ceil(math.log2(m))
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros,
                                  pack_factor, best_config)
            out_shape = x.shape[:-1] + (qweight.shape[1] * 8, )
        else:
            out = ops.awq_gemm(reshaped_x,
                               qweight,
                               zeros_and_scales,
                               m,
                               n,
                               k,
                               self.quant_config.group_size,
                               padding_group,
                               self.awqsingleton.awqworkshapce,
                               self.awqsingleton.awqworkshapcesize)
            out_shape = x.shape[:-1] + (qweight.shape[0], )

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)