import os
import gc
import json

import torch
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file

from huggingface_hub import snapshot_download
import transformers
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)

from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelFor*` classes from transformers,
# we need a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}
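
# A minimal sketch of how this mapping is consumed below: a checkpoint's
# `config.model_type` selects a transformers Auto class by name, e.g.
#
#   target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["llama"]  # -> "AutoModelForCausalLM"
#   target_cls = getattr(transformers, target_cls_name)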

class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                 calib_data: Union[str, List[str]] = "pileval",
                 split="train", text_column="text", duo_scaling=True,
                 modules_to_not_convert=None, export_compatible=False):
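        """
        Quantize the wrapped model in place with AWQ.

        A minimal usage sketch (the model path is illustrative; entry points
        are the package's usual `AutoAWQForCausalLM` / `AutoTokenizer`):

            model = AutoAWQForCausalLM.from_pretrained("org/model")
            tokenizer = AutoTokenizer.from_pretrained("org/model")
            model.quantize(tokenizer, quant_config={
                "zero_point": True, "q_group_size": 128,
                "w_bit": 4, "version": "GEMM",
            })
            model.save_quantized("model-awq")
        """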
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config or {})

        self.quantizer = AwqQuantizer(
            self, self.model, tokenizer, self.quant_config.w_bit,
            self.quant_config.q_group_size, self.quant_config.version,
            calib_data, split, text_column, duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()
        
        self.is_quantized = True
    
    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized will
        overwrite existing weights if you use the same quant_path.
        
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        ) 
        model.save_quantized(...)  # produces GGUF/other compat weights
        model.pack(...) # makes the model CUDA compat
        model.save_quantized(...)  # produces CUDA compat weights
        """
        self.quantizer.pack()
    
    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
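        """Save the quantized model, its configs, and (for vision models) the
        processor to `save_dir`, sharding weights into `shard_size` chunks."""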
        save_dir = save_dir.rstrip("/")

        # Placeholder module so config files can be saved without any weights
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove the placeholder files written with the empty state dict
        default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )
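        # With multiple shards this follows the standard transformers naming,
        # e.g. model-00001-of-00002.safetensors, plus an index file (written
        # below) mapping each tensor name to its shard.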

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory,
                # so save a contiguous clone of each tensor
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))
        
    @classmethod
    def from_pretrained(self, model_path, model_type,
                        torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
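        """
        Load an unquantized model from `model_path`, ready for `quantize()`.

        A minimal usage sketch (the path is illustrative; in practice this is
        called through a concrete subclass or the package's Auto wrapper):

            model = AutoAWQForCausalLM.from_pretrained("org/model")
        """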
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # Load the unquantized model with the mapped transformers Auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs
        )

        model.eval()

        return self(model, model_type, is_quantized=False, config=config, 
                    quant_config=quant_config, processor=processor)

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='',
                       max_new_tokens=None, torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=True, is_quantized=True,
                       fuse_layers=False, version='GEMM',
                       device_map="balanced", offload_folder=None,
                       **config_kwargs):
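        """
        Load an AWQ-quantized model from `model_path` for inference.

        A minimal usage sketch (path illustrative; usually called through a
        concrete subclass or the package's Auto wrapper):

            model = AutoAWQForCausalLM.from_quantized("org/model-awq", fuse_layers=True)
        """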
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, model_filename, safetensors, version, 
            trust_remote_code, max_new_tokens=max_new_tokens,
            **config_kwargs
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)
        
        # [STEP 3] Build the model skeleton on the meta device (no weights allocated yet)
        with init_empty_weights():
            model = target_cls.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config.version)
        
        model.tie_weights()

        # Load the weights into the modules and distribute them
        # across the available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )
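        # Note: device_map follows accelerate's conventions; besides the
        # "balanced" default you can pass "auto", "sequential", or an explicit
        # {module_name: device} mapping.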
        
        # Optionally fuse layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        return self(model, model_type, is_quantized=is_quantized, config=config,
                    quant_config=quant_config, processor=None)

    def _load_config(self, model_path, model_filename, safetensors=True,
                     version="GEMM", trust_remote_code=True, max_new_tokens=4096,
                     **config_kwargs):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = os.path.join(model_path, model_filename)
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # Multi-modal models keep their generation settings on a nested text config
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = max_new_tokens
        
        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
        # Replace nn.Linear modules with WQLinear so quantized weights can be loaded
        assert quant_config.zero_point, "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unknown version: {version}, expected 'GEMM' or 'GEMV'")

                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config.w_bit,
                    quant_config.q_group_size,
                    True,  # init_only: set up empty quantized buffers; weights are loaded afterwards
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
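        """Wrap the block's scalable activation (if any) in ScaledActivation;
        a no-op when `get_act_for_scaling` reports nothing scalable."""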
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # placeholder scales (ones); real values are set by quantization or loaded from the checkpoint
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # wrap the activation in its scaled equivalent
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)