awq.py 11 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
from typing import Any, Dict, List, Optional

import torch
zhuwenwen's avatar
zhuwenwen committed
6
import os
7
import torch.nn.functional as F
8
9
10
11
import vllm.envs as envs
import json
import math
from vllm.platforms import current_platform
12
from vllm import _custom_ops as ops
13
14
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
17
18
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
from vllm.logger import init_logger
# Module-level logger for this file.
logger = init_logger(__name__)
# Process-wide cache of tuned Triton kernel configs. Entries are added by
# get_triton_cache() and read by getspec_config(); keys are strings of the
# form "<M>_<N>_<K>" and values are dicts of integer kernel parameters.
triton_configs_dict={}

def get_triton_cache(file_path):
    """Load tuned Triton AWQ kernel configs from a JSON file.

    The JSON document is expected to be a two-level mapping
    ``{key: {sub_key: {param: value, ...}}}``. Every leaf is stored in the
    module-level ``triton_configs_dict`` under the flattened key
    ``"<sub_key>_<key>"`` (looked up elsewhere as ``"<M>_<N>_<K>"``), with
    all kernel parameters converted to ``int``.

    Args:
        file_path: Path of the JSON config file.

    Returns:
        None. The function only updates ``triton_configs_dict``. A missing
        file is skipped with a warning instead of failing on an unbound
        local variable.
    """
    if not os.path.exists(file_path):
        logger.warning("%s does not exist, skip loading triton configs.",
                       file_path)
        return

    with open(file_path, 'r') as file:
        cachedata = json.load(file)

    # Parameters every entry must provide; insertion order is preserved.
    required_fields = (
        'SPLIT_K',
        'BLOCK_SIZE_M',
        'BLOCK_SIZE_N',
        'BLOCK_SIZE_K',
        'GROUP_SIZE_M',
        'num_stages',
        'num_warps',
    )

    # Flatten the nested cache into "key: config" form: [M_N_K]: [config].
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_value = {
                field: int(sub_value[field])
                for field in required_fields
            }
            # Optional parameter, only present in some tuned files.
            if 'num_ldmatrixes' in sub_value:
                configs_value["num_ldmatrixes"] = int(
                    sub_value['num_ldmatrixes'])
            triton_configs_dict[f"{sub_key}_{key}"] = configs_value
    logger.info("%s have loaded!", file_path)

def default_execution(k,n):
    """Load the tuned Triton config file for an (n, k) weight shape, if any.

    The file is looked up next to this module under ``configs/awq/`` as
    ``AWQ_<n>_<k>_<device>.json``, where ``<device>`` is the current
    platform's device name with spaces replaced by underscores.

    Args:
        k: First shape component used in the file name and cache key.
        n: Second shape component used in the file name and cache key.

    Returns:
        None. When a matching file exists its entries are added to
        ``triton_configs_dict`` through ``get_triton_cache``.
    """
    # An existing "1_<n>_<k>" entry is taken as proof that the configs for
    # this shape were already loaded.
    if f"1_{n}_{k}" in triton_configs_dict:
        return

    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_dir = f"{module_dir}/configs/awq/"
    device = current_platform.get_device_name().replace(" ", "_")
    config_path = os.path.join(config_dir, f"AWQ_{n}_{k}_{device}.json")

    # Nothing to do unless a regular .json file exists for this shape.
    if not os.path.isfile(config_path):
        return
    if not config_path.endswith(".json"):
        return
    get_triton_cache(config_path)


def getspec_config(M,N,K):
    """Return the cached Triton config for an (M, N, K) problem.

    Args:
        M, N, K: Values joined into the lookup key ``"<M>_<N>_<K>"``.

    Returns:
        The config dict stored in ``triton_configs_dict`` for that key, or
        ``None`` when no entry exists.
    """
    return triton_configs_dict.get(f"{M}_{N}_{K}")
70
71


72
73
74
75
76
77
78
79
80
81
class AWQShareWorkSpace:
    """Singleton holding the shared AWQ workspace handle and its size.

    The first construction creates the instance and calls ``_initialize``
    once; every later construction returns that same instance.
    """

    # The single shared instance; ``None`` until first construction.
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Do not forward constructor arguments: ``object.__new__``
            # raises TypeError when it receives any, which made the first
            # construction fail while later ones silently accepted them.
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Query the custom ops for the workspace and its size."""
        # NOTE: the attribute spelling ("workshapce") is read by callers
        # in this file and is therefore kept as is.
        self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
        self.awqworkshapce = ops.GetAWQShareWorkspace()
84
85
86
87
88
89
90
91
92
93
94
95
96


class AWQConfig(QuantizationConfig):
    """Quantization settings for AWQ checkpoints.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[List[str]] = None,
    ) -> None:
        """Store the settings and derive the int32 packing factor.

        Raises:
            ValueError: If ``weight_bits`` is anything other than 4.
        """
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # A missing (or empty) skip list is normalised to an empty list.
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values stored in one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        shown = ("weight_bits", "group_size", "zero_point",
                 "modules_to_not_convert")
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"AWQConfig({body})"

    def get_name(self) -> str:
        """Return the name of this quantization method."""
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported GPU compute capability."""
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names probed for the quantization config."""
        return [
            # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quant_config.json",
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        """Build an instance from a parsed quantization config dict.

        Each setting is read from the first matching key among its known
        aliases; ``modules_to_not_convert`` is optional.
        """
        bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group = cls.get_from_keys(config, ["q_group_size", "group_size"])
        has_zero_point = cls.get_from_keys(config, ["zero_point"])
        skipped = cls.get_from_keys_or(config, ["modules_to_not_convert"],
                                       None)
        return cls(bits, group, has_zero_point, skipped)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
        """Pick the linear method for ``layer``.

        Returns ``None`` for layers that are not ``LinearBase``, an
        unquantized method for layers whose prefix matches the skip list,
        and an ``AWQLinearMethod`` otherwise.
        """
        if not isinstance(layer, LinearBase):
            return None
        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        return AWQLinearMethod(self)
151
152


153
154
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
    """Tell whether a layer must be left unquantized.

    Args:
        prefix: Name (prefix) of the layer being inspected.
        modules_to_not_convert: Module names excluded from quantization.

    Returns:
        True if any listed module name occurs as a substring of ``prefix``,
        otherwise False.
    """
    for skipped_name in modules_to_not_convert:
        if skipped_name in prefix:
            return True
    return False
155

156
157
158
159
160
161
162
163
164
165

class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config
        # Shared workspace singleton; its handle and size are forwarded to
        # ops.awq_gemm in apply().
        self.awqsingleton = AWQShareWorkSpace()
        # Group padding is opt-in through the AWQ_PAD=1 environment variable.
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register the AWQ parameters on ``layer``.

        Registered parameters (K = input_size_per_partition,
        N = sum(output_partition_sizes), G = group_size, P = pack_factor):

        * ``qweight``: int32, shape (K, N // P)
        * ``qzeros``: int32, shape (K // G, N // P)
        * ``scales``: ``params_dtype``, shape (K // G, N)
        * ``zeros_and_scales``: int32, shape (K // G, N)

        When ``envs.VLLM_USE_TRITON_AWQ`` is set, the tuned Triton configs
        for this (K, N) shape are loaded as well.

        Raises:
            ValueError: If K is not a multiple of the group size, or N is
                not a multiple of the pack factor.
        """
        group_size = self.quant_config.group_size
        pack_factor = self.quant_config.pack_factor

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        num_groups = input_size_per_partition // group_size
        packed_output_size = output_size_per_partition // pack_factor

        weight_loader = extra_weight_attrs.get("weight_loader")
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_output_size,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        # NOTE(review): nothing in this class fills this buffer; presumably
        # it is populated elsewhere before ops.awq_gemm consumes it. Confirm.
        zeros_and_scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros_and_scales", zeros_and_scales)

        # Load the tuned Triton configs for this weight shape.
        if envs.VLLM_USE_TRITON_AWQ:
            default_execution(input_size_per_partition,
                              output_size_per_partition)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain, non-trainable Parameters."""
        layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                           requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                          requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data,
                                          requires_grad=False)
        layer.zeros_and_scales = torch.nn.Parameter(
            layer.zeros_and_scales.data, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized matmul for ``layer`` on ``x``.

        ``x`` is flattened to 2-D, multiplied through either the Triton
        kernel (when ``envs.VLLM_USE_TRITON_AWQ`` is set) or
        ``ops.awq_gemm``, optionally offset in place by ``bias``, and
        reshaped back to the leading dimensions of ``x``.

        Args:
            layer: Module carrying the parameters made by create_weights.
            x: Input activations; the last dimension is the reduction dim.
            bias: Optional bias added in place to the kernel output.

        Returns:
            The output tensor with shape ``x.shape[:-1] + (N,)``.
        """
        qweight = layer.qweight
        zeros_and_scales = layer.zeros_and_scales
        qzeros = layer.qzeros
        scales = layer.scales
        pack_factor = self.quant_config.pack_factor

        reshaped_x = x.reshape(-1, x.shape[-1])
        m = reshaped_x.shape[0]
        k = reshaped_x.shape[-1]

        if envs.VLLM_USE_TRITON_AWQ:
            # qweight is packed along dim 1, so the logical output width is
            # shape[1] * pack_factor. The tuned-config key must use this N;
            # qweight.shape[0] is the input width and would produce an
            # "M_K_K" key that misses the cache whenever N != K.
            n = qweight.shape[1] * pack_factor

            # For M above 16 the lookup uses the next power of two.
            lookup_m = m
            if lookup_m > 16:
                lookup_m = 2 ** math.ceil(math.log2(lookup_m))
            best_config = getspec_config(lookup_m, n, k)

            out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros,
                                  pack_factor, best_config)
        else:
            # NOTE(review): this path takes N from qweight.shape[0], which
            # presumably relies on a kernel-specific weight layout. Confirm.
            n = qweight.shape[0]

            # Padding applies only when enabled and K is a multiple of 4096.
            if self.use_awq_pad and k % 4096 == 0:
                padding_group = 2
            else:
                padding_group = 0

            out = ops.awq_gemm(reshaped_x,
                               qweight,
                               zeros_and_scales,
                               m,
                               n,
                               k,
                               self.quant_config.group_size,
                               padding_group,
                               self.awqsingleton.awqworkshapce,
                               self.awqsingleton.awqworkshapcesize)

        out_shape = (x.shape[:-1] + (n, ))
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)