awq.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Union
5

zhuwenwen's avatar
zhuwenwen committed
6
import os
7
import json
8
import torch
9
10
from vllm import envs

11
from vllm.platforms import current_platform
12
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
13

14
from vllm import _custom_ops as ops
15
16
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
17
18
19
20
21
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
22
from vllm.model_executor.layers.quantization.base_config import (
23
        QuantizationConfig,
24
25
    QuantizeMethodBase,
)
26
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
27
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
28
29
30
31
32
from vllm.transformers_utils.config import get_safetensors_params_metadata

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper
33
 
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
# Process-wide cache of tuned Triton AWQ GEMM configs, keyed "M_N_K"
# (filled by get_triton_cache, read by getspec_config).
triton_configs_dict: dict[str, dict[str, int]] = dict()

def get_triton_cache(file_path):
    """Load one tuned Triton AWQ config file into ``triton_configs_dict``.

    The JSON is a two-level mapping ``{outer_key: {inner_key: config}}``.
    Each config is stored under ``f"{inner_key}_{outer_key}"`` (the
    ``"M_N_K"`` keys that ``getspec_config`` looks up), with every field
    converted to ``int``. ``num_ldmatrixes`` is optional. The other seven
    fields are required, and a missing one raises ``KeyError``.

    A path that does not exist is ignored. Previously that case fell through
    to the parsing loop and raised ``UnboundLocalError`` on ``cachedata``.

    Args:
        file_path: Path to the JSON config file.

    Returns:
        None. Results are written into ``triton_configs_dict``.
    """
    if not os.path.exists(file_path):
        return

    with open(file_path, "r", encoding="utf-8") as file:
        cachedata = json.load(file)

    # Flatten every entry into "<inner>_<outer>" -> config.
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {
                "SPLIT_K": int(sub_value["SPLIT_K"]),
                "BLOCK_SIZE_M": int(sub_value["BLOCK_SIZE_M"]),
                "BLOCK_SIZE_N": int(sub_value["BLOCK_SIZE_N"]),
                "BLOCK_SIZE_K": int(sub_value["BLOCK_SIZE_K"]),
                "GROUP_SIZE_M": int(sub_value["GROUP_SIZE_M"]),
                "num_stages": int(sub_value["num_stages"]),
                "num_warps": int(sub_value["num_warps"]),
            }
            if "num_ldmatrixes" in sub_value:
                configs_value["num_ldmatrixes"] = int(sub_value["num_ldmatrixes"])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k, n):
    """Load the tuned Triton AWQ configs for an (N, K) layer shape.

    Looks for ``configs/awq/AWQ_{n}_{k}_{device}.json`` next to this module
    (device name with spaces replaced by underscores) and passes it to
    ``get_triton_cache`` if it exists. The presence of the M=1 entry
    (key ``"1_{n}_{k}"``) serves as the "already loaded" marker, so the
    file is not re-read for every layer of the same shape.

    Args:
        k: Input (reduction) size of the layer partition.
        n: Output size of the layer partition.
    """
    already_loaded = f"1_{n}_{k}" in triton_configs_dict
    if already_loaded:
        return

    here = os.path.dirname(os.path.abspath(__file__))
    config_dir = f"{here}/configs/awq/"
    device_tag = current_platform.get_device_name().replace(" ", "_")
    candidate = os.path.join(config_dir, f"AWQ_{n}_{k}_{device_tag}.json")

    # Tuned configs are optional: silently skip shapes without a file.
    if not os.path.isfile(candidate):
        return
    get_triton_cache(candidate)


def getspec_config(M, N, K):
    """Return the tuned Triton config for an (M, N, K) GEMM, if any.

    Row counts up to 16 are looked up exactly. Larger M is rounded up to
    the next power of two (presumably matching the M buckets stored in the
    tuned JSON files; confirm against the tuning script).

    Args:
        M: Number of activation rows.
        N: Output size.
        K: Input (reduction) size.

    Returns:
        The config dict stored under ``"{bucket}_{N}_{K}"`` in
        ``triton_configs_dict``, or None when no entry exists.
    """
    if M <= 16:
        bucket = M
    else:
        # Smallest power of two that is >= M.
        bucket = 1
        while bucket < M:
            bucket *= 2
    return triton_configs_dict.get(f"{bucket}_{N}_{K}")
89
90


91
92
93
94
95
96
97
98
99
100
class AWQShareWorkSpace:
    """Singleton holding the shared workspace for the custom AWQ GEMM.

    On first instantiation, ``_initialize`` fetches the workspace and its size
    from ``ops.GetAWQShareWorkspace`` / ``ops.GetAWQShareWorkspaceSize``.
    Every later ``AWQShareWorkSpace()`` returns that same object. The
    attribute names ``awqworkshapce`` / ``awqworkshapcesize`` are kept
    unchanged because callers read them directly.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Do not forward *args/**kwargs: object.__new__ raises TypeError
            # on extra arguments when __new__ is overridden. The singleton
            # takes no configuration, so they are ignored.
            instance = super().__new__(cls)
            instance._initialize()
            # Cache only after a successful init, so a failed init is retried
            # instead of returning a half-built object on later calls.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Fetch the shared workspace buffer and its size from the custom ops."""
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
103
104


105
106
# Module-level logger. get_triton_cache (defined earlier in this file) and
# AWQConfig.get_quant_method reference it at call time, so it only has to
# exist before those functions first run.
logger = init_logger(__name__)

107
108
109
110
111
112
113
114
115
116
117
118

class AWQConfig(QuantizationConfig):
    """Quantization config for AWQ (activation-aware weight quantization).

    Only 4-bit weights are accepted, so ``pack_factor`` (values per int32)
    is always 8.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: list[str] | None = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # Module names that stay unquantized; None becomes an empty list.
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_name(self) -> "QuantizationMethods":
        """Return the registry name of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the checkpoint file names that may hold the AWQ config."""
        filenames = [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]
        return filenames

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an AWQConfig from a checkpoint quantization-config dict.

        Either spelling of each key is accepted (``w_bit``/``bits``,
        ``q_group_size``/``group_size``). ``modules_to_not_convert`` is
        optional.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        skip_modules = cls.get_from_keys_or(config, ["modules_to_not_convert"], None)
        return cls(weight_bits, group_size, zero_point, skip_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
        """Choose the quantize method for ``layer``.

        Linear layers get ``AWQLinearMethod``, or ``UnquantizedLinearMethod``
        when ``prefix`` matches ``modules_to_not_convert``. ``FusedMoE`` layers
        are delegated to the AWQ Marlin config, or to MoE WNA16 when
        ``check_moe_marlin_supports_layer`` rejects the layer. Any other
        layer type yields None.
        """
        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            )
            return UnquantizedLinearMethod() if skipped else AWQLinearMethod(self)

        if not isinstance(layer, FusedMoE):
            return None

        # Lazy import to avoid circular import.
        from .awq_marlin import AWQMarlinConfig
        from .moe_wna16 import MoeWNA16Config
        from .utils.marlin_utils import check_moe_marlin_supports_layer

        # Both MoE back ends are configured from the same AWQ settings.
        moe_quant_config = {
            "quant_method": "awq",
            "bits": self.weight_bits,
            "group_size": self.group_size,
            "zero_point": self.zero_point,
            "lm_head": False,
            "modules_to_not_convert": self.modules_to_not_convert,
        }
        if check_moe_marlin_supports_layer(layer, self.group_size):
            backend = AWQMarlinConfig.from_config(moe_quant_config)
        else:
            logger.warning_once(
                f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                "Falling back to Moe WNA16 kernels."
            )
            backend = MoeWNA16Config.from_config(moe_quant_config)
        return backend.get_quant_method(layer, prefix)

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``modules_to_not_convert`` from HF names to vLLM names."""
        if not self.modules_to_not_convert:
            return
        self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
            self.modules_to_not_convert
        )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        """Infer ``modules_to_not_convert`` from safetensors metadata.

        Runs only when no skip list was configured. A module (the parameter
        name minus its last dotted component) is treated as unquantized when
        none of its parameters has a dtype outside fp16/bf16/fp32.
        """
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)

        every_module: set[str] = set()
        quantized_modules: set[str] = set()
        for param_name, info in metadata.items():
            module_name = param_name.rsplit(".", 1)[0]
            every_module.add(module_name)
            dtype = info.get("dtype", None)
            if dtype and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes:
                quantized_modules.add(module_name)
        self.modules_to_not_convert = list(every_module - quantized_modules)
239

240
241
242
243
244
245
246
247
248
249

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    The GEMM back end is selected by ``envs.VLLM_USE_TRITON_AWQ``:

    * Triton: ``qweight`` / ``qzeros`` / ``scales`` keep their checkpoint
      layout (``qweight`` is ``(K, N // pack_factor)``). They are consumed by
      ``awq_gemm_triton`` with a tuned config from ``getspec_config``.
    * Custom kernel (default): after loading, the weights are repacked by
      ``ops.convert_s4`` / ``ops.sz_permute`` into ``qweight`` and
      ``zeros_and_scales``, both reshaped to N rows. They are optionally
      zero-padded (``AWQ_PAD=1``) and consumed by ``torch.ops.vllm.awq_gemm``.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Shared workspace buffer passed to the custom AWQ GEMM kernel.
        self.awqsingleton = AWQShareWorkSpace()
        # AWQ_PAD=1 enables zero padding for layers whose K is a multiple of
        # 4096 (see process_weights_after_loading and apply).
        self.use_awq_pad = os.environ.get("AWQ_PAD") == "1"

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the AWQ parameters on ``layer``.

        With K = ``input_size_per_partition``, N = ``sum(output_partition_sizes)``
        and G = number of quantization groups, this registers:
        ``qweight`` (K, N / pack_factor) int32, ``qzeros`` (G, N / pack_factor)
        int32, ``scales`` (G, N) in ``params_dtype``, and ``zeros_and_scales``
        (G, N) int32. On the custom-kernel path, ``process_weights_after_loading``
        replaces ``zeros_and_scales``.

        Raises:
            ValueError: if K is not a multiple of the group size, or N is not
                a multiple of ``pack_factor``.
        """
        # group_size == -1 means one group spanning the whole input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        pack_factor = self.quant_config.pack_factor
        packed_cols = output_size_per_partition // pack_factor
        num_groups = input_size_per_partition // group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_cols,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_cols,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        # Use the normalized num_groups. The raw config value may be -1, which
        # previously produced a negative dimension and crashed torch.empty.
        zeros_and_scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load the tuned Triton configs for this (N, K) shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition, output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded checkpoint tensors for the selected GEMM path.

        Triton path: re-wrap ``qweight`` / ``qzeros`` / ``scales`` as frozen
        parameters; their layout is unchanged.

        Custom-kernel path: repack with ``ops.convert_s4`` (scales cast to
        float16) and ``ops.sz_permute``, and reshape both results to
        ``dim_n`` rows. When ``AWQ_PAD=1`` and K is a multiple of 4096,
        append zero columns. Store the results as ``qweight`` and
        ``zeros_and_scales``, and set ``qzeros`` / ``scales`` to None.
        """
        if envs.VLLM_USE_TRITON_AWQ:
            layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
            return

        group_size = self.quant_config.group_size
        # Number of zero columns appended to zeros_and_scales; apply() passes
        # the same value to the kernel as padding_group.
        pad_group = 2
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]

        qw, packed_sz = ops.convert_s4(
            layer.qweight, layer.qzeros, layer.scales.to(torch.float16), int(group_size)
        )
        # reshape preserves row-major element order, so the former
        # intermediate reshape(-1, dim_n) had no effect and is dropped.
        sz = ops.sz_permute(packed_sz).reshape(dim_n, -1)
        qw = qw.reshape(dim_n, -1)

        if dim_k % 4096 == 0 and self.use_awq_pad:
            # Allocate padding directly on the weights' device rather than on
            # the CPU followed by a .cuda() copy to the current device.
            sz_pad = torch.zeros(dim_n, pad_group, dtype=torch.int32, device=sz.device)
            sz = torch.cat((sz, sz_pad), dim=1).contiguous()
            qw_pad = torch.zeros(
                dim_n, int(group_size // 4), dtype=torch.int32, device=qw.device
            )
            qw = torch.cat((qw, qw_pad), dim=1).contiguous()

        layer.qweight = torch.nn.Parameter(qw, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
        layer.qzeros = None
        layer.scales = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the quantized linear layer ``x @ W (+ bias)``.

        ``x`` is flattened to ``(m, k)`` rows. The output is reshaped back to
        ``x.shape[:-1] + (n,)``, where n is the layer's output size derived
        from the weight layout of the active path.
        """
        qweight = layer.qweight
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])
        m = reshaped_x.shape[0]
        k = reshaped_x.shape[-1]

        if envs.VLLM_USE_TRITON_AWQ:
            # Checkpoint layout: qweight is (K, N // pack_factor). The old
            # code used qweight.shape[0] (= K) as n, so the lookup key never
            # matched the "M_N_K" entries loaded by default_execution.
            n = qweight.shape[1] * pack_factor
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(
                reshaped_x, qweight, layer.scales, layer.qzeros, pack_factor, best_config
            )
        else:
            # Repacked layout from process_weights_after_loading: N rows.
            n = qweight.shape[0]
            # Must mirror the padding condition in process_weights_after_loading.
            padding_group = 2 if self.use_awq_pad and k % 4096 == 0 else 0
            out = torch.ops.vllm.awq_gemm(
                reshaped_x,
                qweight,
                layer.zeros_and_scales,
                m,
                n,
                k,
                self.quant_config.group_size,
                padding_group,
                self.awqsingleton.awqworkshapce,
                self.awqsingleton.awqworkshapcesize,
            )

        if bias is not None:
            out.add_(bias)
        # Use the real output size n. The old qweight.shape[-1] * pack_factor
        # equals K (plus padding) for the repacked layout, so it broke the
        # reshape for non-square or padded layers.
        return out.reshape(x.shape[:-1] + (n,))