awq.py 15.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Union
5

zhuwenwen's avatar
zhuwenwen committed
6
import os
7
import json
8
import torch
9
10
from vllm import envs

11
from vllm.platforms import current_platform
12
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
13

14
from vllm import _custom_ops as ops
15
16
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
17
18
19
20
21
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
22
from vllm.model_executor.layers.quantization.base_config import (
23
        QuantizationConfig,
24
25
    QuantizeMethodBase,
)
26
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
27
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
28
29
30
31
32
from vllm.transformers_utils.config import get_safetensors_params_metadata

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper
33
 
34
35
36
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
# Module-level cache of tuned Triton AWQ GEMM configs. get_triton_cache()
# fills it from JSON files; keys are "M_N_K" strings and each value maps a
# kernel parameter name (e.g. "BLOCK_SIZE_M", "num_warps") to an int.
triton_configs_dict: dict[str, dict[str, int]] = {}

37
38
39
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
    """Tell whether the layer at ``prefix`` should stay unquantized.

    A layer is skipped when any entry of ``modules_to_not_convert`` occurs as a
    substring of its prefix.
    """
    for skipped_name in modules_to_not_convert:
        if skipped_name in prefix:
            return True
    return False

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def get_triton_cache(file_path):
    """Load tuned Triton AWQ GEMM configs from a JSON file into ``triton_configs_dict``.

    The file is a two-level mapping ``{outer_key: {inner_key: config}}``. Each
    config is stored under ``f"{inner_key}_{outer_key}"``, i.e. the ``M_N_K``
    lookup key. Every required field is converted to ``int``; the optional
    ``num_ldmatrixes`` field is copied only when present.

    A missing file is logged and skipped. Previously it crashed with
    UnboundLocalError because ``cachedata`` was never assigned.

    Returns:
        None. Results are written into the module-level ``triton_configs_dict``.
    """
    required_fields = (
        "SPLIT_K",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_stages",
        "num_warps",
    )

    if not os.path.exists(file_path):
        logger.warning("AWQ Triton config file %s not found; skipping.", file_path)
        return

    with open(file_path, "r", encoding="utf-8") as file:
        cachedata = json.load(file)

    # Flatten every cached entry into key:config form, keyed "M_N_K".
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {field: int(sub_value[field]) for field in required_fields}
            if "num_ldmatrixes" in sub_value:
                configs_value["num_ldmatrixes"] = int(sub_value["num_ldmatrixes"])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k,n):
    """Lazily load tuned Triton AWQ configs for a GEMM with output size ``n`` and input size ``k``.

    Does nothing if the ``M=1`` entry for this shape (key ``"1_{n}_{k}"``) is
    already cached. Otherwise it looks for
    ``configs/awq/AWQ_{n}_{k}_{device_name}.json`` next to this module, with
    spaces in the device name replaced by underscores. If that file exists, it
    is loaded via ``get_triton_cache``.
    """
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    device_tag = current_platform.get_device_name().replace(" ", "_")
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "configs",
        "awq",
        f"AWQ_{n}_{k}_{device_tag}.json",
    )
    # The file name always ends in ".json", so only existence needs checking.
    if os.path.isfile(config_path):
        get_triton_cache(config_path)


def getspec_config(M,N,K):
    """Return the tuned Triton config for an ``(M, N, K)`` AWQ GEMM, or ``None``.

    ``M`` values up to 16 are looked up exactly. Larger ``M`` is rounded up to
    the next power of two, so many batch sizes share one tuned entry.
    """
    if M <= 16:
        bucket_m = M
    else:
        bucket_m = 1
        while bucket_m < M:
            bucket_m <<= 1
    return triton_configs_dict.get(f"{bucket_m}_{N}_{K}")
92
93


94
95
96
97
98
99
100
101
102
103
class AWQShareWorkSpace:
    """Singleton holding the AWQ GEMM workspace handle and its size.

    The first construction calls ``_initialize``, which queries
    ``ops.GetAWQShareWorkspaceSize()`` and ``ops.GetAWQShareWorkspace()``.
    Every later ``AWQShareWorkSpace()`` returns that same instance.
    Constructor arguments are accepted and ignored.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ raises TypeError on extra arguments when __new__ is
            # overridden, so forward only the class.
            instance = super().__new__(cls)
            instance._initialize()
            # Cache only after successful initialization. A failed attempt must
            # not leave a half-built singleton without the workspace attributes.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Fetch the shared workspace size and handle from the custom ops."""
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
106
107


108
109
# Module logger, used by get_triton_cache and AWQConfig.get_quant_method.
# It is bound after get_triton_cache is defined; that is fine because the
# name is resolved only when the function runs.
logger = init_logger(__name__)

110
111
112
113
114
115
116
117
118
119
120
121

class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: list[str] | None = None,
    ) -> None:
        """Store the AWQ quantization parameters.

        Args:
            weight_bits: Bits per quantized weight; only 4 is supported.
            group_size: Quantization group size along the input dimension
                (-1 is treated by AWQLinearMethod as one group per input row span).
            zero_point: Whether the checkpoint uses zero points.
            modules_to_not_convert: Layer-name substrings to leave unquantized.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        # Number of quantized values packed into one int32 (32 // 4 == 8).
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_name(self) -> "QuantizationMethods":
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an AWQConfig from a checkpoint quantization-config dict.

        Bits are read from ``w_bit``/``bits`` and the group size from
        ``q_group_size``/``group_size``. ``zero_point`` is required;
        ``modules_to_not_convert`` is optional and defaults to ``None``.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def _moe_config_dict(self) -> dict[str, Any]:
        """Build the AWQ config dict handed to the MoE quantization configs."""
        return {
            "quant_method": "awq",
            "bits": self.weight_bits,
            "group_size": self.group_size,
            "zero_point": self.zero_point,
            "lm_head": False,
            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
        """Pick the quantization method for ``layer``.

        Linear layers get AWQLinearMethod, or UnquantizedLinearMethod when
        ``prefix`` is listed in ``modules_to_not_convert``. FusedMoE layers use
        AWQ Marlin when the layer supports it and fall back to MoE WNA16
        otherwise. Any other layer type returns None.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # Lazy import to avoid circular import.
            from .awq_marlin import AWQMarlinConfig
            from .moe_wna16 import MoeWNA16Config
            from .utils.marlin_utils import check_moe_marlin_supports_layer

            # Both MoE backends consume the same AWQ description.
            moe_config = self._moe_config_dict()
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels."
                )
                return MoeWNA16Config.from_config(moe_config).get_quant_method(
                    layer, prefix
                )
            awq_marlin_config = AWQMarlinConfig.from_config(moe_config)
            return awq_marlin_config.get_quant_method(layer, prefix)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``modules_to_not_convert`` from HF names to vLLM names."""
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        """Infer ``modules_to_not_convert`` from safetensors metadata if it is unset.

        A layer (parameter name minus its last component) is treated as
        unquantized when none of its parameters has a dtype outside
        fp16/bf16/fp32.
        """
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)
242

243
244
245
246
247
248
249
250
251
252

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Two execution paths are selected by ``envs.VLLM_USE_TRITON_AWQ``:

    * Triton: keeps the packed ``qweight``/``qzeros``/``scales`` as loaded and
      calls ``awq_gemm_triton`` with a tuned config when one is cached.
    * Custom kernel: repacks the weights into ``qweight`` and
      ``zeros_and_scales`` (via ``ops.convert_s4``/``ops.sz_permute``) and calls
      ``torch.ops.vllm.awq_gemm`` with a shared workspace.

    Args:
        quant_config: The AWQ quantization config.
    """

    # Extra int32 columns appended to zeros_and_scales when AWQ_PAD padding is
    # active; the same value is passed to the kernel as `padding_group`.
    _PAD_GROUP = 2
    # AWQ_PAD padding applies only to layers whose input size K is a multiple of this.
    _PAD_K_MULTIPLE = 4096

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Workspace shared by every AWQ layer (AWQShareWorkSpace is a singleton).
        self.awqsingleton = AWQShareWorkSpace()
        # Opt-in padding for the custom kernel, enabled with AWQ_PAD=1.
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def _use_padding(self, dim_k: int) -> bool:
        """Whether AWQ_PAD padding applies to a layer with input size ``dim_k``."""
        return self.use_awq_pad and dim_k % self._PAD_K_MULTIPLE == 0

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the AWQ parameters on ``layer``.

        With K = ``input_size_per_partition``, N = ``sum(output_partition_sizes)``,
        G = K // group_size and P = ``pack_factor``, it registers:

        * ``qweight``: int32 ``(K, N // P)``
        * ``qzeros``: int32 ``(G, N // P)``
        * ``scales``: ``params_dtype`` ``(G, N)``
        * ``zeros_and_scales``: int32 ``(G, N)``; replaced by the repacked
          tensor in process_weights_after_loading on the custom-kernel path.

        Raises:
            ValueError: If K is not divisible by the group size, or N is not
                divisible by ``pack_factor``.
        """
        # Normalize group_size: -1 means one group spanning the whole input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        # Use the normalized group count. Dividing by the raw group_size (-1)
        # would request a negative dimension and crash.
        zeros_and_scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load tuned Triton configs for this (N, K) shape, if available.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition, output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare loaded weights for whichever GEMM path ``apply`` will use."""
        if envs.VLLM_USE_TRITON_AWQ:
            # The Triton kernel consumes the loaded tensors directly; re-wrap
            # them as plain frozen Parameters.
            layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
            layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
            layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
            return

        group_size = self.quant_config.group_size
        dim_n = layer.scales.data.shape[1]
        dim_k = layer.qweight.data.shape[0]
        # NOTE(review): with group_size == -1 this passes -1 to convert_s4, and the
        # padding branch would ask for a negative width; confirm per-channel AWQ
        # is supported on this path.
        _qw, _sz = ops.convert_s4(
            layer.qweight, layer.qzeros, layer.scales.to(torch.float16), int(group_size)
        )
        sz = ops.sz_permute(_sz).reshape(-1, dim_n)
        sz = sz.reshape(dim_n, -1)
        _qw = _qw.reshape(dim_n, -1)

        if self._use_padding(dim_k):
            # Allocate padding directly on the weights' device instead of the
            # CPU followed by .cuda(), which targets the *current* device.
            zeros_and_scales_pad = torch.zeros(
                dim_n, self._PAD_GROUP, dtype=torch.int32, device=sz.device
            )
            sz = torch.cat((sz, zeros_and_scales_pad), dim=1).contiguous()
            qweight_pad = torch.zeros(
                dim_n, int(group_size // 4), dtype=torch.int32, device=_qw.device
            )
            _qw = torch.cat((_qw, qweight_pad), dim=1).contiguous()

        layer.qweight = torch.nn.Parameter(_qw, requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False)
        # The repacked tensors replace the originals on this path.
        layer.qzeros = None
        layer.scales = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W`` (plus optional ``bias``) with the AWQ-quantized weight.

        ``x`` is flattened to ``(m, k)``. The output is reshaped back to
        ``x.shape[:-1] + (n,)``.
        """
        qweight = layer.qweight
        pack_factor = self.quant_config.pack_factor
        reshaped_x = x.reshape(-1, x.shape[-1])
        m = reshaped_x.shape[0]
        k = reshaped_x.shape[-1]

        if envs.VLLM_USE_TRITON_AWQ:
            # Triton layout: qweight is (K, N // pack_factor), so N comes from
            # dim 1. The config cache is keyed "M_N_K" (see default_execution),
            # so N must be the real output size, not qweight.shape[0] (== K).
            n = qweight.shape[1] * pack_factor
            best_config = getspec_config(m, n, k)
            out = awq_gemm_triton(
                reshaped_x, qweight, layer.scales, layer.qzeros, pack_factor, best_config
            )
        else:
            # Repacked layout: process_weights_after_loading reshaped qweight to (N, -1).
            n = qweight.shape[0]
            padding_group = self._PAD_GROUP if self._use_padding(k) else 0
            out = torch.ops.vllm.awq_gemm(
                reshaped_x,
                qweight,
                layer.zeros_and_scales,
                m,
                n,
                k,
                self.quant_config.group_size,
                padding_group,
                self.awqsingleton.awqworkshapce,
                self.awqsingleton.awqworkshapcesize,
            )

        if bias is not None:
            out.add_(bias)
        return out.reshape(x.shape[:-1] + (n,))