awq.py 14 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional, Union
5
6

import torch
zhuwenwen's avatar
zhuwenwen committed
7
import os
8
import torch.nn.functional as F
9
10
11
12
import vllm.envs as envs
import json
import math
from vllm.platforms import current_platform
13
from vllm import _custom_ops as ops
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
16
17
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
18
from vllm.model_executor.layers.quantization import QuantizationMethods
19
from vllm.model_executor.layers.quantization.base_config import (
20
    QuantizationConfig, QuantizeMethodBase)
21
22
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
from vllm.logger import init_logger
logger = init_logger(__name__)
# Cache of tuned Triton AWQ GEMM configs, keyed by "{M}_{N}_{K}" strings
# (the format looked up by getspec_config and default_execution). It is
# filled by get_triton_cache; each value maps config names such as
# "BLOCK_SIZE_M" or "num_warps" to ints.
triton_configs_dict: dict[str, dict[str, int]] = {}

def get_triton_cache(file_path):
    """Parse a tuned Triton AWQ config file into ``triton_configs_dict``.

    The JSON file is a two-level mapping ``{outer_key: {inner_key: cfg}}``.
    Each ``cfg`` is normalized to ints and stored in the module-level
    ``triton_configs_dict`` under the composite key
    ``f"{inner_key}_{outer_key}"``. ``getspec_config`` and
    ``default_execution`` read these keys back as ``"{M}_{N}_{K}"``, so
    presumably the outer key is ``"{N}_{K}"`` and the inner key is ``M``
    (TODO confirm against the tuning script that writes these files).

    Args:
        file_path: Path of the JSON config file.

    Returns:
        None. If ``file_path`` does not exist, a warning is logged and
        nothing is loaded.
    """
    if not os.path.exists(file_path):
        # Without this guard ``cachedata`` would be unbound below.
        logger.warning("AWQ Triton config file %s not found; skipping.",
                       file_path)
        return

    with open(file_path, encoding="utf-8") as file:
        cachedata = json.load(file)

    # Flatten the nested cache into "{inner}_{outer}" -> config entries.
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {
                "SPLIT_K": int(sub_value["SPLIT_K"]),
                "BLOCK_SIZE_M": int(sub_value["BLOCK_SIZE_M"]),
                "BLOCK_SIZE_N": int(sub_value["BLOCK_SIZE_N"]),
                "BLOCK_SIZE_K": int(sub_value["BLOCK_SIZE_K"]),
                "GROUP_SIZE_M": int(sub_value["GROUP_SIZE_M"]),
                "num_stages": int(sub_value["num_stages"]),
                "num_warps": int(sub_value["num_warps"]),
            }
            # Optional key: copied only when the file provides it.
            if "num_ldmatrixes" in sub_value:
                configs_value["num_ldmatrixes"] = int(
                    sub_value["num_ldmatrixes"])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k, n):
    """Load the tuned Triton config file for an (N, K) AWQ GEMM, if present.

    The file searched for is ``configs/awq/AWQ_{n}_{k}_{device}.json`` in
    this module's directory, where ``device`` is the platform device name
    with spaces replaced by underscores. It is parsed by
    ``get_triton_cache``. Nothing is done when ``triton_configs_dict``
    already has an ``M=1`` entry for this shape, and a missing file is
    silently ignored.

    Args:
        k: Input (reduction) dimension; callers pass the per-partition
            input size.
        n: Output dimension; callers pass the per-partition output size.
    """
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    config_dir = f"{os.path.dirname(os.path.abspath(__file__))}/configs/awq/"
    device_tag = current_platform.get_device_name().replace(" ", "_")
    candidate = os.path.join(config_dir, f"AWQ_{n}_{k}_{device_tag}.json")
    # The name is built with a ".json" suffix, so only existence matters.
    if os.path.isfile(candidate):
        get_triton_cache(candidate)


def _next_pow2_at_least(value):
    """Return the smallest power of two that is >= ``value`` (1 for value <= 1)."""
    power = 1
    while power < value:
        power *= 2
    return power


def getspec_config(M, N, K):
    """Return the tuned Triton config for an (M, N, K) GEMM, or ``None``.

    M values up to 16 are looked up exactly. Larger M values are rounded up
    to the next power of two before building the ``"{M}_{N}_{K}"`` key into
    ``triton_configs_dict``.
    """
    bucket_m = M if M <= 16 else _next_pow2_at_least(M)
    return triton_configs_dict.get(f"{bucket_m}_{N}_{K}")
80
81


82
83
84
85
86
87
88
89
90
91
class AWQShareWorkSpace:
    """Singleton holding the shared AWQ GEMM workspace.

    The first instantiation calls ``ops.GetAWQShareWorkspaceSize()`` and
    ``ops.GetAWQShareWorkspace()`` once. Every later ``AWQShareWorkSpace()``
    call returns that same instance. The attribute names, including the
    ``workshapce`` spelling, are read by ``AWQLinearMethod.apply`` and must
    not be renamed.

    There is no locking, so concurrent first calls could initialize twice.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ raises TypeError on extra arguments when
            # __new__ is overridden, so they are deliberately not forwarded.
            instance = super().__new__(cls)
            # Initialize before caching, so a failed init is retried on the
            # next call instead of leaving a half-built singleton behind.
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
94

95
96
# NOTE(review): redundant re-initialization; ``logger`` is already created by
# an identical assignment next to the Triton config helpers earlier in the
# module.
logger = init_logger(__name__)

97
98
99
100
101
102
103
104
105
106
107
108

class AWQConfig(QuantizationConfig):
    """Quantization config for AWQ (Activation-aware Weight Quantization).

    Only 4-bit weights are accepted; ``pack_factor`` is the number of such
    values stored per 32-bit word.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # Name fragments of linear layers that must stay unquantized
        # (matched by substring in is_layer_skipped_awq).
        self.modules_to_not_convert = (modules_to_not_convert
                                       if modules_to_not_convert else [])

        supported_bits = 4
        if self.weight_bits != supported_bits:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return ("AWQConfig("
                f"weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    def get_name(self) -> QuantizationMethods:
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        # E.g. casperhansen/vicuna-7b-v1.5-awq uses the first name and
        # abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq the second.
        return ["quant_config.json", "quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group = cls.get_from_keys(config, ["q_group_size", "group_size"])
        has_zero_point = cls.get_from_keys(config, ["zero_point"])
        skip_modules = cls.get_from_keys_or(config,
                                            ["modules_to_not_convert"], None)
        return cls(bits, group, has_zero_point, skip_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Pick the quantization method for ``layer``.

        Linear layers get AWQLinearMethod unless listed in
        ``modules_to_not_convert``. FusedMoE layers get AWQ Marlin MoE,
        falling back to Moe WNA16 when Marlin cannot handle the layer.
        Anything else returns None.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)

        if not isinstance(layer, FusedMoE):
            return None

        # Lazy import to avoid circular import.
        from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
        from .moe_wna16 import MoeWNA16Config
        from .utils.marlin_utils import check_moe_marlin_supports_layer

        base_moe_config = {
            "quant_method": "awq",
            "bits": self.weight_bits,
            "group_size": self.group_size,
            "zero_point": self.zero_point,
            "lm_head": False,
        }

        if not check_moe_marlin_supports_layer(layer, self.group_size):
            logger.warning_once(
                f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                "Falling back to Moe WNA16 kernels.")
            return MoeWNA16Config.from_config(base_moe_config).get_quant_method(
                layer, prefix)

        awq_marlin_config = AWQMarlinConfig.from_config({
            **base_moe_config,
            "modules_to_not_convert": self.modules_to_not_convert,
        })
        return AWQMoEMethod(awq_marlin_config, layer.moe_config)
194
195


196
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
    """Return True if any entry of ``modules_to_not_convert`` occurs in ``prefix``.

    Matching is plain substring containment, not an exact-name or
    dotted-component match.
    """
    for skipped in modules_to_not_convert:
        if skipped in prefix:
            return True
    return False
198

199
200
201
202
203
204
205
206
207
208

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Two execution paths are selected by ``envs.VLLM_USE_TRITON_AWQ``:

    * Triton: ``qweight``/``qzeros``/``scales`` are kept in checkpoint
      layout and passed to ``awq_gemm_triton`` with a tuned config from
      ``getspec_config``.
    * Default: after loading, the weights are converted with
      ``ops.convert_s4``/``ops.sz_permute`` into ``qweight`` plus
      ``zeros_and_scales`` and passed to ``torch.ops.vllm.awq_gemm`` together
      with a shared workspace. With env ``AWQ_PAD=1``, both tensors are
      padded when K is a multiple of ``_PAD_K_MULTIPLE``.

    Args:
        quant_config: The AWQ quantization config.
    """

    # AWQ_PAD layout. When K is a multiple of _PAD_K_MULTIPLE,
    # zeros_and_scales gets _PAD_GROUP extra int32 columns after loading, and
    # the same value is passed to the kernel as ``padding_group``.
    _PAD_K_MULTIPLE = 4096
    _PAD_GROUP = 2

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Workspace buffer shared by every AWQ layer (singleton).
        self.awqsingleton = AWQShareWorkSpace()
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def _pad_enabled(self, k: int) -> bool:
        """Whether AWQ_PAD padding applies to a layer with input dim ``k``."""
        return self.use_awq_pad and k % self._PAD_K_MULTIPLE == 0

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed AWQ parameters on ``layer``.

        Raises:
            ValueError: If the per-partition input size is not a multiple of
                the group size, or the output size is not a multiple of
                ``pack_factor``.
        """
        # A group_size of -1 means a single group spanning the input dim.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        pack_factor = self.quant_config.pack_factor
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        packed_cols = output_size_per_partition // pack_factor
        num_groups = input_size_per_partition // group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_cols,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_cols,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        # Placeholder that process_weights_after_loading overwrites on the
        # non-Triton path. It is sized with the normalized num_groups: the raw
        # group_size (-1 for per-channel) would give a negative dimension.
        zeros_and_scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=torch.int32,
        ),
                                                    input_dim=0,
                                                    output_dim=1,
                                                    weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Preload tuned Triton configs for this layer's (N, K) shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition,
                              output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert (default path) or re-wrap (Triton path) loaded weights."""
        if envs.VLLM_USE_TRITON_AWQ:
            # The Triton kernel consumes the checkpoint layout directly.
            layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                               requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                              requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data,
                                              requires_grad=False)
            return

        # NOTE(review): group_size == -1 is passed through unchanged to
        # convert_s4 and the pad width below; confirm the kernel supports it.
        group_size = self.quant_config.group_size
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]

        qweight, zeros_scales = ops.convert_s4(layer.qweight, layer.qzeros,
                                               layer.scales.to(torch.float16),
                                               int(group_size))
        # Both converted tensors are laid out with N as the leading dim.
        zeros_scales = ops.sz_permute(zeros_scales).reshape(dim_n, -1)
        qweight = qweight.reshape(dim_n, -1)

        if self._pad_enabled(dim_k):
            # Allocate pads on the tensors' own device (not the current CUDA
            # device) so torch.cat never mixes devices.
            sz_pad = torch.zeros(dim_n,
                                 self._PAD_GROUP,
                                 dtype=torch.int32,
                                 device=zeros_scales.device)
            zeros_scales = torch.cat((zeros_scales, sz_pad),
                                     dim=1).contiguous()
            qweight_pad = torch.zeros(dim_n,
                                      int(group_size // 4),
                                      dtype=torch.int32,
                                      device=qweight.device)
            qweight = torch.cat((qweight, qweight_pad), dim=1).contiguous()

        layer.qweight = torch.nn.Parameter(qweight, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(zeros_scales,
                                                    requires_grad=False)
        # The original zeros/scales are folded into zeros_and_scales.
        layer.qzeros = None
        layer.scales = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W (+ bias)`` with the dequantizing AWQ GEMM."""
        qweight = layer.qweight
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])
        m, k = reshaped_x.shape

        if envs.VLLM_USE_TRITON_AWQ:
            # Checkpoint layout: qweight is (K, N // pack_factor), so N comes
            # from the packed column count. qweight.shape[0] is K here and
            # would look up tuned configs under the wrong "{M}_{N}_{K}" key.
            n = qweight.shape[1] * pack_factor
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(reshaped_x, qweight, layer.scales,
                                  layer.qzeros, pack_factor, best_config)
        else:
            # Converted layout: qweight was reshaped to (N, -1) after loading.
            n = qweight.shape[0]
            padding_group = self._PAD_GROUP if self._pad_enabled(k) else 0
            out = torch.ops.vllm.awq_gemm(reshaped_x,
                                          qweight,
                                          layer.zeros_and_scales,
                                          m,
                                          n,
                                          k,
                                          self.quant_config.group_size,
                                          padding_group,
                                          self.awqsingleton.awqworkshapce,
                                          self.awqsingleton.awqworkshapcesize)

        if bias is not None:
            out.add_(bias)
        return out.reshape(x.shape[:-1] + (n, ))