awq.py 14.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Union
5

zhuwenwen's avatar
zhuwenwen committed
6
import os
7
import json
8
import torch
9
10
from vllm import envs

11
from vllm.platforms import current_platform
12
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
13

14
from vllm import _custom_ops as ops
15
16
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
17
18
19
20
21
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
22
from vllm.model_executor.layers.quantization.base_config import (
23
        QuantizationConfig,
24
25
    QuantizeMethodBase,
)
26
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
27
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
28
29
30
31
32
from vllm.transformers_utils.config import get_safetensors_params_metadata

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper
33
 
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
# Process-wide cache of tuned Triton AWQ configs, keyed "{M}_{outer key}".
triton_configs_dict = {}

# Integer-valued fields every cached config entry must provide, in the order
# they are inserted into the parsed config dict.
_TRITON_CONFIG_INT_FIELDS = (
    "SPLIT_K",
    "BLOCK_SIZE_M",
    "BLOCK_SIZE_N",
    "BLOCK_SIZE_K",
    "GROUP_SIZE_M",
    "num_stages",
    "num_warps",
)


def get_triton_cache(file_path):
    """Load tuned Triton AWQ configs from a JSON file into ``triton_configs_dict``.

    The file maps an outer key to a ``{M: config}`` dict. Each config is
    stored under ``"{M}_{outer key}"`` with its fields converted to ``int``.
    The optional ``num_ldmatrixes`` field is kept when present. Callers build
    filenames as ``AWQ_{n}_{k}_...``, so the outer key is presumably
    ``"{N}_{K}"``. TODO confirm against the config generator.

    A missing file leaves the cache unchanged, and nothing is returned.
    """
    if not os.path.exists(file_path):
        # Nothing to load. Without this guard the parsing loop below would
        # reference an unbound ``cachedata``.
        return

    with open(file_path, "r", encoding="utf-8") as file:
        cachedata = json.load(file)

    for outer_key, entries in cachedata.items():
        for m_key, entry in entries.items():
            config = {
                field: int(entry[field]) for field in _TRITON_CONFIG_INT_FIELDS
            }
            if "num_ldmatrixes" in entry:
                config["num_ldmatrixes"] = int(entry["num_ldmatrixes"])
            triton_configs_dict[f"{m_key}_{outer_key}"] = config
    logger.info("%s have loaded!", file_path)

def default_execution(k,n):
    """Load the tuned Triton AWQ configs for a weight of shape (N, K), if any.

    Returns immediately when the M=1 entry for this (N, K) is already in
    ``triton_configs_dict``. Otherwise it looks for
    ``configs/awq/AWQ_{n}_{k}_{device}.json`` next to this module, where
    spaces in the device name are replaced by underscores. If that file
    exists, it is merged into the cache via ``get_triton_cache``.
    """
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    device_tag = current_platform.get_device_name().replace(" ", "_")
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "configs",
        "awq",
        f"AWQ_{n}_{k}_{device_tag}.json",
    )
    # The generated name always ends in ".json", so only existence matters.
    if os.path.isfile(config_path):
        get_triton_cache(config_path)


def getspec_config(M,N,K):
    """Look up the tuned Triton config for an (M, N, K) AWQ GEMM.

    M values up to 16 are looked up exactly. Larger M values are rounded up
    to the next power of two first, so one tuned entry covers a range of
    batch sizes.

    Returns:
        The entry stored in ``triton_configs_dict`` under ``"{M}_{N}_{K}"``
        (with M bucketed as above), or ``None`` when there is no such entry.
    """
    bucket = M
    if M > 16:
        # Smallest power of two that is >= M.
        bucket = 1
        while bucket < M:
            bucket <<= 1
    return triton_configs_dict.get(f"{bucket}_{N}_{K}")
89
90


91
92
93
94
95
96
97
98
99
100
class AWQShareWorkSpace:
    """Process-wide singleton holding the shared AWQ GEMM workspace.

    The workspace handle and its size are queried once from the custom ops
    (``ops.GetAWQShareWorkspace`` / ``ops.GetAWQShareWorkspaceSize``). The
    same object is then returned by every later ``AWQShareWorkSpace()`` call.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Arguments are accepted for backward compatibility but ignored.
        # Forwarding them to ``object.__new__`` would raise TypeError.
        if cls._instance is None:
            instance = super().__new__(cls)
            # Initialize before publishing the singleton, so that a failing
            # op call does not leave a half-initialized instance cached.
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Query the shared workspace size and handle from the custom ops."""
        # These attribute names (spelling included) are read by
        # AWQLinearMethod.apply, so they must stay as they are.
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()


logger = init_logger(__name__)

107
108
109
110
111
112
113
114
115
116
117
118

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: list[str] | None = None,
    ) -> None:
        """
        Args:
            weight_bits: Bits per quantized weight; only 4 is accepted.
            group_size: Quantization group size along the input dimension.
            zero_point: Whether the checkpoint uses zero points.
            modules_to_not_convert: Module names to leave unquantized.

        Raises:
            ValueError: if ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        # Number of quantized values packed into one int32 (8 for 4-bit).
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_name(self) -> "QuantizationMethods":
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an AWQConfig from a checkpoint's quantization config dict."""
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
        """Return the quantization method for ``layer``.

        The mapping is:
        * Linear layers get ``AWQLinearMethod``, or ``UnquantizedLinearMethod``
          when the layer is listed in ``modules_to_not_convert``.
        * FusedMoE layers get the AWQ Marlin MoE method. If the Marlin check
          rejects the layer, they fall back to the MoE WNA16 method.
        * Any other layer type gets ``None``.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer

            if not check_moe_marlin_supports_layer(layer, self.group_size):
                # Lazy %-style args instead of an eagerly built f-string.
                logger.warning_once(
                    "Layer '%s' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.",
                    prefix,
                )
                config = {
                    "quant_method": "awq",
                    "bits": self.weight_bits,
                    "group_size": self.group_size,
                    "zero_point": self.zero_point,
                    "lm_head": False,
                }
                return MoeWNA16Config.from_config(config).get_quant_method(
                    layer, prefix
                )
            marlin_compatible_config_dict = {
                "quant_method": "awq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "zero_point": self.zero_point,
                "lm_head": False,
                "modules_to_not_convert": self.modules_to_not_convert,
            }
            awq_marlin_config = AWQMarlinConfig.from_config(
                marlin_compatible_config_dict
            )
            return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``modules_to_not_convert`` entries from HF to vLLM names."""
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        """Infer ``modules_to_not_convert`` from checkpoint dtypes if unset.

        A module is marked as not-to-convert when none of its parameters is
        stored in a dtype other than float16, bfloat16 or float32. Parameters
        without dtype metadata count as unquantized.

        Dtype strings missing from the safetensors dtype table are treated as
        quantized; previously such a string raised KeyError here.
        """
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE.get(dtype) not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)
238

239
240
241
242
243
244
245
246
247
248

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    ``envs.VLLM_USE_TRITON_AWQ`` selects one of two kernel paths.

    * Triton path: the checkpoint tensors ``qweight``
      (K, N // pack_factor), ``qzeros`` and ``scales`` are used as-is by
      ``awq_gemm_triton``, with a tuned config looked up via
      ``getspec_config``.
    * Default path: ``process_weights_after_loading`` converts the checkpoint
      tensors with ``ops.convert_s4`` / ``ops.sz_permute`` into a new
      ``qweight`` and ``zeros_and_scales``, both reshaped to N rows. These
      are fed to ``torch.ops.vllm.awq_gemm``.

    With the environment variable ``AWQ_PAD=1``, the default path appends
    zero padding groups to the converted tensors when K is a multiple of
    4096.

    Args:
        quant_config: The AWQ quantization config.
    """

    # Zero-filled groups appended per output row when padding applies.
    _PAD_GROUPS = 2
    # Padding only applies when K (input features) is a multiple of this.
    _PAD_K_MULTIPLE = 4096

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Shared workspace handed to the default-path GEMM kernel.
        self.awqsingleton = AWQShareWorkSpace()
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def _padding_groups(self, k: int) -> int:
        """Return the number of padding groups the default path uses for K.

        Both ``process_weights_after_loading`` (which appends the padding)
        and ``apply`` (which reports it to the kernel) call this helper, so
        the two always agree.
        """
        if self.use_awq_pad and k % self._PAD_K_MULTIPLE == 0:
            return self._PAD_GROUPS
        return 0

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the AWQ parameters on ``layer``.

        Here K = ``input_size_per_partition``, N =
        ``sum(output_partition_sizes)``, and num_groups = K // group_size.
        The registered parameters are:

        * ``qweight``: (K, N // pack_factor), int32.
        * ``qzeros``: (num_groups, N // pack_factor), int32.
        * ``scales``: (num_groups, N), in ``params_dtype``.
        * ``zeros_and_scales``: (num_groups, N), int32 placeholder.

        Raises:
            ValueError: if K is not a multiple of the group size or N is not
                a multiple of the pack factor.
        """
        pack_factor = self.quant_config.pack_factor

        # Normalize group_size: -1 means a single group spanning the input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        # Placeholder for the default path's converted zero/scale tensor; it
        # is replaced in process_weights_after_loading. It is sized with the
        # normalized num_groups: dividing by a raw group_size of -1 would
        # give a negative dimension.
        zeros_and_scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load the tuned Triton configs for this (N, K) weight shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition, output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded weights for the selected kernel path.

        * Triton path: the checkpoint tensors are re-wrapped as frozen
          Parameters.
        * Default path: the tensors are converted with
          ``ops.convert_s4`` / ``ops.sz_permute`` and padded when
          ``_padding_groups`` says so. The results are stored as ``qweight``
          and ``zeros_and_scales``, and ``qzeros`` / ``scales`` are cleared.
        """
        if envs.VLLM_USE_TRITON_AWQ:
            layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
            return

        group_size = self.quant_config.group_size
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]

        _qw, _sz = ops.convert_s4(
            layer.qweight, layer.qzeros, layer.scales.to(torch.float16), int(group_size)
        )
        # Both converted tensors are laid out with one row per output feature.
        sz = ops.sz_permute(_sz).reshape(dim_n, -1)
        qw = _qw.reshape(dim_n, -1)

        pad_groups = self._padding_groups(dim_k)
        if pad_groups:
            # Allocate on the converted tensor's device rather than the
            # current default CUDA device.
            sz_pad = torch.zeros(dim_n, pad_groups, dtype=torch.int32, device=sz.device)
            sz = torch.cat((sz, sz_pad), dim=1)
            # pad_groups * group_size packed values, pack_factor per int32.
            # For the 2-group padding this equals group_size // 4.
            qw_pad_width = int(pad_groups * group_size // self.quant_config.pack_factor)
            qw_pad = torch.zeros(dim_n, qw_pad_width, dtype=torch.int32, device=qw.device)
            qw = torch.cat((qw, qw_pad), dim=1)

        layer.qweight = torch.nn.Parameter(qw, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
        layer.qzeros = None
        layer.scales = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by the AWQ-quantized weight and add ``bias`` if given.

        ``x`` is flattened to (M, K) for the kernel. The result is reshaped
        to ``x.shape[:-1] + (N,)``.
        """
        qweight = layer.qweight
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])
        m, k = reshaped_x.shape

        if envs.VLLM_USE_TRITON_AWQ:
            # Checkpoint layout: qweight is (K, N // pack_factor). Use the real
            # N so the lookup matches the keys loaded by default_execution.
            n = qweight.shape[1] * pack_factor
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(
                reshaped_x, qweight, layer.scales, layer.qzeros, pack_factor, best_config
            )
        else:
            # Converted layout: qweight was reshaped to N rows in
            # process_weights_after_loading.
            n = qweight.shape[0]
            out = torch.ops.vllm.awq_gemm(
                reshaped_x,
                qweight,
                layer.zeros_and_scales,
                m,
                n,
                k,
                self.quant_config.group_size,
                self._padding_groups(k),
                self.awqsingleton.awqworkshapce,
                self.awqsingleton.awqworkshapcesize,
            )

        if bias is not None:
            out.add_(bias)
        return out.reshape(x.shape[:-1] + (n,))