"vllm/vscode:/vscode.git/clone" did not exist on "a2e9ebe9e242295a58e400835ef98a14b29c4fb0"
awq.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Optional, Union
5
6

import torch
zhuwenwen's avatar
zhuwenwen committed
7
import os
8
import torch.nn.functional as F
9
10
11
12
import vllm.envs as envs
import json
import math
from vllm.platforms import current_platform
13
from vllm import _custom_ops as ops
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
16
17
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
18
from vllm.model_executor.layers.quantization import QuantizationMethods
19
from vllm.model_executor.layers.quantization.base_config import (
20
    QuantizationConfig, QuantizeMethodBase)
21
22
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
from vllm.logger import init_logger
# Module-level logger shared by the helpers and classes in this file.
logger = init_logger(__name__)
# Cache of tuned Triton AWQ kernel configs keyed by the string "M_N_K";
# filled by get_triton_cache() and read by getspec_config().
triton_configs_dict={}

def get_triton_cache(file_path):
    """Load tuned Triton AWQ kernel configs from a JSON file into the cache.

    The JSON document is read as a two-level mapping. For every outer key
    ``key`` and inner key ``sub_key`` the entry is stored in
    ``triton_configs_dict`` under ``f"{sub_key}_{key}"`` (the rest of this
    module looks entries up as ``"M_N_K"``). Each inner value must provide
    ``SPLIT_K``, ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``,
    ``GROUP_SIZE_M``, ``num_stages`` and ``num_warps``; ``num_ldmatrixes``
    is optional. All values are converted to ``int``.

    Args:
        file_path: Path of the JSON config file.

    Returns:
        None. The parsed configs are written to ``triton_configs_dict``.

    Raises:
        KeyError: If an entry lacks one of the required fields.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Previously a missing file left `cachedata` unbound and the loop below
    # crashed with UnboundLocalError; now it is reported and skipped.
    if not os.path.exists(file_path):
        logger.warning("Triton AWQ config file %s does not exist, skipped.",
                       file_path)
        return

    with open(file_path, 'r', encoding="utf-8") as file:
        cachedata = json.load(file)

    required_fields = (
        "SPLIT_K",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_stages",
        "num_warps",
    )

    # Flatten the nested mapping into key -> config entries.
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {
                field: int(sub_value[field])
                for field in required_fields
            }
            if 'num_ldmatrixes' in sub_value:
                configs_value["num_ldmatrixes"] = int(
                    sub_value['num_ldmatrixes'])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k,n):
    """Load the tuned Triton config file for an (n, k) problem if needed.

    Nothing happens when ``triton_configs_dict`` already holds the key
    ``"1_{n}_{k}"``. Otherwise the file
    ``configs/awq/AWQ_{n}_{k}_{device}.json`` next to this module is loaded
    through ``get_triton_cache`` when it exists; ``device`` is the current
    platform's device name with spaces replaced by underscores.

    Args:
        k: Value used as the K component of the config key and file name.
        n: Value used as the N component of the config key and file name.
    """
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    config_dir = f"{os.path.dirname(os.path.abspath(__file__))}/configs/awq/"
    device_tag = current_platform.get_device_name().replace(" ", "_")
    candidate = os.path.join(config_dir, f"AWQ_{n}_{k}_{device_tag}.json")

    # Only an existing regular JSON file is handed to the loader.
    if not (os.path.isfile(candidate) and candidate.endswith(".json")):
        return
    get_triton_cache(candidate)


def getspec_config(M,N,K):
    """Return the cached Triton config stored under ``"M_N_K"``, or None.

    Args:
        M: First component of the lookup key.
        N: Second component of the lookup key.
        K: Third component of the lookup key.

    Returns:
        The config dict from ``triton_configs_dict`` for the key, or ``None``
        when no entry exists.
    """
    return triton_configs_dict.get(f"{M}_{N}_{K}")
74
75


76
77
78
79
80
81
82
83
84
85
class AWQShareWorkSpace:
    """Singleton holding the shared AWQ workspace handles.

    The first construction queries ``ops.GetAWQShareWorkspaceSize()`` and
    ``ops.GetAWQShareWorkspace()`` once; every later construction returns the
    same instance, so the queried values are shared by all users.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating and initializing it once.

        Extra positional/keyword arguments are accepted for compatibility
        but ignored: forwarding them to ``object.__new__`` raises TypeError.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._initialize()
            # Publish only after initialization succeeded, so a failed
            # attempt does not leave a half-initialized singleton behind.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Query the workspace size and workspace from the custom ops."""
        # NOTE: the attribute names (including their spelling) are read by
        # AWQLinearMethod.apply and are therefore kept as they are.
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
88

89
90
# NOTE(review): `logger` is already bound with the same call earlier in this
# module, so this rebinding appears redundant.
logger = init_logger(__name__)

91
92
93
94
95
96
97
98
99
100
101
102

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        """Store the AWQ settings and derive the packing factor.

        Args:
            weight_bits: Bit width of the quantized weights; must be 4.
            group_size: Quantization group size (``-1`` is normalized to the
                full input size by ``AWQLinearMethod.create_weights``).
            zero_point: Zero-point flag taken from the checkpoint config.
            modules_to_not_convert: Name fragments of layers that stay
                unquantized; ``None`` is treated as an empty list.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into one int32.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"awq"``."""
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the supported activation dtypes (only ``torch.half``)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported GPU compute capability."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names that may hold the quantization config."""
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an ``AWQConfig`` from a checkpoint quantization config.

        Args:
            config: Mapping that provides ``w_bit``/``bits``,
                ``q_group_size``/``group_size`` and ``zero_point``;
                ``modules_to_not_convert`` is optional.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Select the quantization method for ``layer``.

        Args:
            layer: The layer being constructed.
            prefix: Qualified name of the layer, matched against
                ``modules_to_not_convert``.

        Returns:
            For ``LinearBase`` layers, ``UnquantizedLinearMethod`` when the
            layer is skipped and ``AWQLinearMethod`` otherwise. For
            ``FusedMoE`` layers, the Marlin MoE method, or the MoE WNA16
            method when Marlin does not support the layer. ``None`` for any
            other layer type.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)

        if isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer

            # Fields shared by both MoE backends.
            moe_config = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
            }
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(moe_config).get_quant_method(
                    layer, prefix)

            # Only the Marlin config receives the skip list.
            moe_config["modules_to_not_convert"] = self.modules_to_not_convert
            awq_marlin_config = AWQMarlinConfig.from_config(moe_config)
            return AWQMoEMethod(awq_marlin_config)

        return None
188
189


190
def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: list[str]) -> bool:
    """Return True if any listed module name is a substring of ``prefix``.

    Args:
        prefix: Qualified name of the layer being checked.
        modules_to_not_convert: Name fragments of layers that must stay
            unquantized. An empty list never matches.
    """
    return any(module_name in prefix for module_name in modules_to_not_convert)
192

193
194
195
196
197
198
199
200
201
202

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    # Padding applied by the non-Triton path when AWQ_PAD=1 and the input
    # dimension is a multiple of _PAD_K_MULTIPLE. The same values must be
    # used when the weights are padded and when the kernel is invoked.
    _PAD_GROUPS = 2
    _PAD_K_MULTIPLE = 4096

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Shared workspace handles passed to ops.awq_gemm in apply().
        self.awqsingleton = AWQShareWorkSpace()
        # Padding is opt-in through the AWQ_PAD environment variable.
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the AWQ parameters on ``layer``.

        Registers ``qweight``, ``qzeros``, ``scales`` and
        ``zeros_and_scales``. When ``envs.VLLM_USE_TRITON_AWQ`` is set, the
        tuned Triton configs for this layer size are loaded as well.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features of this partition.
            output_partition_sizes: Output features of each fused partition.
            input_size: Full input size; used as group size when the config
                specifies ``-1``.
            output_size: Full output size (unused here).
            params_dtype: Dtype of the ``scales`` parameter.
            **extra_weight_attrs: May contain ``weight_loader``.

        Raises:
            ValueError: If the partition sizes are not aligned with the
                group size or the pack factor.
        """
        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        packed_out = output_size_per_partition // self.quant_config.pack_factor
        num_groups = input_size_per_partition // group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        # Sized with the normalized group count: dividing by the raw config
        # value produced a negative dimension when group_size is -1.
        zeros_and_scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=torch.int32,
        ),
                                                    input_dim=0,
                                                    output_dim=1,
                                                    weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load the tuned Triton configs for this (K, N) problem size.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition,
                              output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weights into the layout used by ``apply``.

        With ``envs.VLLM_USE_TRITON_AWQ`` the three checkpoint tensors are
        only re-wrapped as plain non-trainable parameters. Otherwise they are
        converted through ``ops.convert_s4`` / ``ops.sz_permute``, reshaped
        to ``(dim_n, -1)``, optionally padded, and stored as ``qweight`` and
        ``zeros_and_scales``; ``qzeros`` and ``scales`` are set to ``None``.
        """
        if envs.VLLM_USE_TRITON_AWQ:
            layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                               requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                              requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data,
                                              requires_grad=False)
            return

        # NOTE(review): the raw config value is forwarded, so -1 is not
        # normalized on this path -- confirm whether ops.convert_s4 supports
        # it.
        group_size = self.quant_config.group_size
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]

        _qw, _sz = ops.convert_s4(layer.qweight, layer.qzeros, layer.scales,
                                  int(group_size))
        sz = ops.sz_permute(_sz).reshape(-1, dim_n)
        sz = sz.reshape(dim_n, -1)
        _qw = _qw.reshape(dim_n, -1)

        if dim_k % self._PAD_K_MULTIPLE == 0 and self.use_awq_pad:
            # Allocate the padding on the device of the tensor it is
            # concatenated with rather than on the current CUDA device.
            sz_pad = torch.zeros(dim_n,
                                 self._PAD_GROUPS,
                                 dtype=torch.int32,
                                 device=sz.device)
            sz = torch.cat((sz, sz_pad), dim=1).contiguous()
            qweight_pad = torch.zeros(dim_n,
                                      int(group_size // 4),
                                      dtype=torch.int32,
                                      device=_qw.device)
            _qw = torch.cat((_qw, qweight_pad), dim=1).contiguous()

        layer.qweight = torch.nn.Parameter(_qw, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
        layer.qzeros = None
        layer.scales = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized linear layer output for ``x``.

        Args:
            layer: Layer holding the processed AWQ parameters.
            x: Input whose last dimension is the input feature dimension.
            bias: Optional bias added in place to the result.

        Returns:
            A tensor with the leading dimensions of ``x`` and the layer's
            output size as the last dimension.
        """
        qweight = layer.qweight
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])

        m = reshaped_x.shape[0]
        k = reshaped_x.shape[-1]

        if envs.VLLM_USE_TRITON_AWQ:
            # On this path qweight keeps its loaded layout, so the output
            # size is the packed second dimension times the pack factor.
            # Using shape[0] (which is K here) made the tuned-config lookup
            # miss for every non-square layer.
            n = qweight.shape[1] * pack_factor
            # Above 16 rows, M is rounded up to the next power of two for
            # the config lookup.
            lookup_m = m if m <= 16 else 1 << (m - 1).bit_length()
            best_config = getspec_config(lookup_m, n, k)
            out = awq_gemm_triton(reshaped_x, qweight, layer.scales,
                                  layer.qzeros, pack_factor, best_config)
        else:
            # After process_weights_after_loading qweight is (dim_n, -1).
            n = qweight.shape[0]
            if self.use_awq_pad and k % self._PAD_K_MULTIPLE == 0:
                padding_group = self._PAD_GROUPS
            else:
                padding_group = 0
            out = ops.awq_gemm(reshaped_x,
                               qweight,
                               layer.zeros_and_scales,
                               m,
                               n,
                               k,
                               self.quant_config.group_size,
                               padding_group,
                               self.awqsingleton.awqworkshapce,
                               self.awqsingleton.awqworkshapcesize)

        out_shape = x.shape[:-1] + (n, )
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)