awq.py 14.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional, Union
5
6

import torch
zhuwenwen's avatar
zhuwenwen committed
7
import os
8
import torch.nn.functional as F
9
10
11
12
import vllm.envs as envs
import json
import math
from vllm.platforms import current_platform
13
from vllm import _custom_ops as ops
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
16
17
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
18
from vllm.model_executor.layers.quantization import QuantizationMethods
19
from vllm.model_executor.layers.quantization.base_config import (
20
    QuantizationConfig, QuantizeMethodBase)
21
22
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
23
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
24
import lightop
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from vllm.logger import init_logger

logger = init_logger(__name__)

# Cache of tuned Triton AWQ GEMM launch configs keyed by "M_N_K" strings;
# filled by get_triton_cache() and read by getspec_config().
triton_configs_dict: dict[str, dict[str, int]] = {}

def get_triton_cache(file_path):
    """Load tuned Triton AWQ GEMM configs from a JSON file into the cache.

    The JSON is a two-level mapping ``{key: {sub_key: cfg}}``. Every ``cfg``
    is stored in the module-level ``triton_configs_dict`` under
    ``f"{sub_key}_{key}"``, which the lookups in this module treat as an
    ``"M_N_K"`` string. Presumably ``key`` is ``"N_K"`` and ``sub_key`` is
    ``M``; verify against the shipped config files.

    Each stored config has the integer fields ``SPLIT_K``, ``BLOCK_SIZE_M``,
    ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_stages`` and
    ``num_warps``, plus ``num_ldmatrixes`` when present in the file.

    Args:
        file_path: Path to the JSON config file.

    Returns:
        None. A missing file is logged as a warning and leaves the cache
        untouched.
    """
    if not os.path.exists(file_path):
        # Best effort: a missing tuning file just means default configs.
        logger.warning("Triton AWQ config file %s not found; skipping.",
                       file_path)
        return

    with open(file_path, encoding="utf-8") as file:
        cachedata = json.load(file)

    required_fields = ("SPLIT_K", "BLOCK_SIZE_M", "BLOCK_SIZE_N",
                       "BLOCK_SIZE_K", "GROUP_SIZE_M", "num_stages",
                       "num_warps")
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            # Values may be stored as strings in the JSON; normalize to int.
            configs_value = {
                field: int(sub_value[field]) for field in required_fields
            }
            if "num_ldmatrixes" in sub_value:
                configs_value["num_ldmatrixes"] = int(
                    sub_value["num_ldmatrixes"])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k, n):
    """Load the tuned Triton AWQ configs for one (N, K) GEMM shape if present.

    Returns immediately when a config for M=1 with this shape is already
    cached. Otherwise it looks for ``configs/awq/AWQ_{n}_{k}_{device}.json``
    next to this module, where ``device`` is the platform device name with
    spaces replaced by underscores. If that file exists, it is loaded via
    ``get_triton_cache``.

    Args:
        k: Reduction dimension of the GEMM.
        n: Output dimension of the GEMM.
    """
    already_cached = f"1_{n}_{k}" in triton_configs_dict
    if already_cached:
        return

    module_dir = os.path.dirname(os.path.abspath(__file__))
    device_tag = current_platform.get_device_name().replace(" ", "_")
    config_path = os.path.join(module_dir, "configs", "awq",
                               f"AWQ_{n}_{k}_{device_tag}.json")

    # The file name always ends in ".json"; only its existence matters.
    if os.path.isfile(config_path):
        get_triton_cache(config_path)


def getspec_config(M, N, K):
    """Look up the cached Triton AWQ GEMM config for an (M, N, K) problem.

    M values up to 16 are looked up exactly. Larger M is rounded up to the
    next power of two (M itself if it already is one) before the lookup.

    Args:
        M: Number of rows of the flattened activation.
        N: Output dimension.
        K: Reduction dimension.

    Returns:
        The cached config dict for key ``"M_N_K"`` (with M bucketed as
        described), or None if nothing is cached under that key.
    """
    m_bucket = M
    if M > 16:
        # Smallest power of two >= M.
        m_bucket = 1 << (M - 1).bit_length()
    return triton_configs_dict.get(f"{m_bucket}_{N}_{K}")
class AWQShareWorkSpace:
    """Process-wide singleton holding the shared workspace for ``awq_gemm``.

    The first instantiation queries ``ops.GetAWQShareWorkspaceSize()`` and
    ``ops.GetAWQShareWorkspace()``; every later instantiation returns that
    same object. The attribute names, including the ``workshapce``
    spelling, are read by ``AWQLinearMethod.apply`` and must stay as-is.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # object.__new__ rejects extra arguments when __new__ is overridden,
        # so they are accepted here for call compatibility but not forwarded.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._initialize()
            # Cache only after a successful init so a failure is retried
            # instead of leaving a half-initialized singleton behind.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Fetch the workspace size and buffer from the custom ops."""
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()

# Module logger (the same init_logger(__name__) call as earlier in the file).
logger = init_logger(__name__)

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        """Validate and store the AWQ quantization parameters.

        Args:
            weight_bits: Bits per quantized weight; only 4 is supported.
            group_size: Quantization group size along the input dimension.
                AWQLinearMethod treats -1 as one group spanning the whole
                input.
            zero_point: The checkpoint's ``zero_point`` flag.
            modules_to_not_convert: Layer-prefix substrings to keep
                unquantized; None means no exclusions.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    def get_name(self) -> QuantizationMethods:
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an AWQConfig from a checkpoint quantization-config dict."""
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Choose the quantization method for ``layer``.

        Linear layers get AWQLinearMethod, or UnquantizedLinearMethod when
        ``prefix`` matches ``modules_to_not_convert``. FusedMoE layers get
        AWQMoEMethod (Marlin) when the layer is supported, otherwise the
        MoE WNA16 method. Any other layer type returns None.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                config = {
                    "quant_method": "awq",
                    "bits": self.weight_bits,
                    "group_size": self.group_size,
                    "zero_point": self.zero_point,
                    "lm_head": False,
                }
                return MoeWNA16Config.from_config(config).get_quant_method(
                    layer, prefix)
            marlin_compatible_config_dict = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
                "modules_to_not_convert": self.modules_to_not_convert,
            }
            awq_marlin_config = AWQMarlinConfig.from_config(
                marlin_compatible_config_dict)
            return AWQMoEMethod(awq_marlin_config, layer.moe_config)
        return None
def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: list[str]) -> bool:
    """Return True if any entry of ``modules_to_not_convert`` occurs in ``prefix``.

    Matching is plain substring containment, not exact module-name equality.
    An empty list never matches.
    """
    return any(module_name in prefix for module_name in modules_to_not_convert)

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    The GEMM back end is selected from environment flags, checked in order:

    * ``envs.VLLM_USE_TRITON_AWQ``: ``awq_gemm_triton`` on the AWQ-packed
      ``qweight``/``scales``/``qzeros`` as loaded. Tuned launch configs come
      from ``default_execution``/``getspec_config``.
    * ``envs.AWQ_GEMM_MARLIN``: ``torch.ops.vllm.gemm_awq_w4a16_marlin`` on
      weights converted by ``ops.convert_s4`` and repacked for Marlin.
    * otherwise: ``torch.ops.vllm.awq_gemm`` on ``ops.convert_s4`` output,
      using the shared ``AWQShareWorkSpace`` buffer.

    Setting the environment variable ``AWQ_PAD=1`` enables optional
    K-padding on the non-Triton paths (see ``_padding_groups``).

    Args:
        quant_config: The AWQ quantization config.
    """

    # With AWQ_PAD=1, padding applies only when K is a multiple of this ...
    _PAD_K_MULTIPLE = 4096
    # ... and then this many padding groups are appended / passed to awq_gemm.
    _PAD_GROUPS = 2

    def __init__(self, quant_config: AWQConfig):
        # Only the default awq_gemm path consumes the shared workspace.
        if not envs.AWQ_GEMM_MARLIN and not envs.VLLM_USE_TRITON_AWQ:
            self.awqsingleton = AWQShareWorkSpace()
        self.quant_config = quant_config
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def _padding_groups(self, k: int) -> int:
        """Return the number of padding groups for reduction size ``k``.

        Non-zero only when ``AWQ_PAD=1`` and ``k`` is a multiple of
        ``_PAD_K_MULTIPLE``. It is used both when padding the converted
        weights and when calling ``awq_gemm``, so the two always agree.
        """
        if self.use_awq_pad and k % self._PAD_K_MULTIPLE == 0:
            return self._PAD_GROUPS
        return 0

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the AWQ parameters on ``layer``.

        Registers:

        * ``qweight`` (int32): ``pack_factor`` values packed along the output
          dim.
        * ``qzeros`` (int32): ``pack_factor`` values packed along the output
          dim.
        * ``scales``: dtype ``params_dtype``.
        * ``zeros_and_scales``: an int32 placeholder of shape
          ``(num_groups, N)``. It is replaced with converted data in
          ``process_weights_after_loading`` on the non-Triton paths.

        On the Triton path this also loads tuned kernel configs for this
        layer's per-partition (N, K) shape.

        Raises:
            ValueError: If the per-partition input size is not a multiple of
                the group size, or the per-partition output size is not a
                multiple of ``pack_factor``.
        """
        pack_factor = self.quant_config.pack_factor
        # group_size == -1 means one group spanning the whole input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        num_groups = input_size_per_partition // group_size
        packed_output_size = output_size_per_partition // pack_factor

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        # Sized with the normalized num_groups, like scales/qzeros. Dividing
        # by the raw config value gives a negative dim when group_size == -1.
        zeros_and_scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=torch.int32,
        ),
                                                    input_dim=0,
                                                    output_dim=1,
                                                    weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load tuned Triton launch configs for this (N, K) shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition,
                              output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded checkpoint tensors for the selected GEMM kernel.

        On the Triton path, ``qweight``/``qzeros``/``scales`` are kept as
        frozen Parameters.

        On the other paths, ``ops.convert_s4`` produces a new ``qweight`` and
        a new ``zeros_and_scales``. The new ``qweight`` is reshaped to
        ``(N, -1)``, optionally padded, and Marlin-repacked when enabled;
        ``zeros_and_scales`` is reshaped to ``(N, -1)``. ``qzeros`` and
        ``scales`` are then set to None.
        """
        if envs.VLLM_USE_TRITON_AWQ:
            layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                               requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                              requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data,
                                              requires_grad=False)
            return

        # NOTE(review): this is the raw config value; -1 is not normalized
        # here or in apply() -- confirm the kernels' handling before relying
        # on group_size == -1 outside the Triton path.
        group_size = self.quant_config.group_size
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]

        # NOTE(review): scales are cast to float16 even for bfloat16 models;
        # presumably convert_s4 only accepts fp16 scales -- confirm.
        converted_qw, converted_sz = ops.convert_s4(
            layer.qweight, layer.qzeros, layer.scales.to(torch.float16),
            int(group_size))
        sz = ops.sz_permute(converted_sz).reshape(dim_n, -1)
        qw = converted_qw.reshape(dim_n, -1)

        pad_groups = self._padding_groups(dim_k)
        if pad_groups:
            # Append zero columns. They are allocated on the tensors' own
            # device rather than via .cuda(), which would allocate on the CPU
            # and then copy to the current CUDA device.
            sz_pad = torch.zeros(dim_n, pad_groups, dtype=torch.int32,
                                 device=sz.device)
            sz = torch.cat((sz, sz_pad), dim=1).contiguous()
            qw_pad = torch.zeros(dim_n, int(group_size // 4),
                                 dtype=torch.int32, device=qw.device)
            qw = torch.cat((qw, qw_pad), dim=1).contiguous()

        if envs.AWQ_GEMM_MARLIN:
            # NOTE(review): padding above also reaches the Marlin repack,
            # although gemm_awq_w4a16_marlin receives no padding count --
            # confirm this is intended.
            qw = torch.ops.vllm.awq_gemm_marlin_weight_repack(qw, dim_n, dim_k)

        layer.qweight = torch.nn.Parameter(qw, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
        layer.qzeros = None
        layer.scales = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the selected AWQ GEMM on ``x`` and add ``bias`` if given.

        ``x`` is flattened to ``(M, K)`` for the kernel. The result is
        reshaped back to ``x.shape[:-1] + (N,)``, where N is this layer's
        per-partition output size, derived from the processed weights.
        """
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])
        m, k = reshaped_x.shape

        if envs.VLLM_USE_TRITON_AWQ:
            # Per-partition N: the same key default_execution() loaded the
            # configs under (not layer.output_size).
            n = layer.qweight.shape[1] * pack_factor
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(reshaped_x, layer.qweight, layer.scales,
                                  layer.qzeros, pack_factor, best_config)
        else:
            # zeros_and_scales was reshaped to (N, -1) after loading, so its
            # first dim is this partition's N.
            n = layer.zeros_and_scales.shape[0]
            if envs.AWQ_GEMM_MARLIN:
                out = torch.ops.vllm.gemm_awq_w4a16_marlin(
                    reshaped_x, layer.qweight, layer.zeros_and_scales, m, n, k)
            else:
                out = torch.ops.vllm.awq_gemm(
                    reshaped_x,
                    layer.qweight,
                    layer.zeros_and_scales,
                    m,
                    n,
                    k,
                    self.quant_config.group_size,
                    self._padding_groups(k),
                    self.awqsingleton.awqworkshapce,
                    self.awqsingleton.awqworkshapcesize)

        if bias is not None:
            out.add_(bias)
        return out.reshape(x.shape[:-1] + (n, ))