import os
import gc
import json

import torch
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file

from huggingface_hub import snapshot_download
import transformers
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelFor*` classes from transformers,
# we need a custom mapping from model type to Auto class name:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}
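
# The mapped name is resolved lazily at load time, e.g. (illustrative):
#
#   target_cls = getattr(transformers, TRANSFORMERS_AUTO_MAPPING_DICT["llama"])
#   # -> transformers.AutoModelForCausalLM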

class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                 calib_data: Union[str, List[str]] = "pileval",
                 split="train", text_column="text", duo_scaling=True,
                 modules_to_not_convert=None):
        quant_config = quant_config if quant_config is not None else {}
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, self.quant_config.w_bit,
            self.quant_config.q_group_size, self.quant_config.version,
            calib_data, split, text_column, duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
        )
        quantizer.quantize()
        self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
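        """Save the quantized model, its configs, and any processor to `save_dir`.

        A minimal usage sketch (the output path is illustrative):

            model.save_quantized("opt-125m-awq", safetensors=True, shard_size="10GB")
        """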
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Placeholder module so save_pretrained serializes no real weights
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove the empty state dict files created by save_pretrained
        default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires tensors that own contiguous memory, so clone each shard tensor
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))
        
    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
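        """Load an unquantized model so it can be quantized afterwards.

        A minimal usage sketch (the model path is illustrative), using the
        `AutoAWQForCausalLM` entry point:

            from awq import AutoAWQForCausalLM

            model = AutoAWQForCausalLM.from_pretrained("facebook/opt-125m")
        """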
        # Get weights path and quant config
        model_weights_path, config, quant_config = cls._load_config(
            cls, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # Load the unquantized model with the mapped transformers Auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs
        )

        model.eval()

        return cls(model, model_type, is_quantized=False, config=config,
                   quant_config=quant_config, processor=processor)

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename='',
                       max_new_tokens=None, torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=True, is_quantized=True,
                       fuse_layers=False, version='GEMM',
                       device_map="balanced", offload_folder=None,
                       **config_kwargs):
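        """Load an AWQ-quantized model and prepare it for inference.

        A minimal usage sketch (the model path is illustrative), using the
        `AutoAWQForCausalLM` entry point:

            from awq import AutoAWQForCausalLM

            model = AutoAWQForCausalLM.from_quantized(
                "casperhansen/opt-125m-awq", fuse_layers=True
            )
        """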
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = cls._load_config(
            cls, model_path, model_filename, safetensors, version, 
            trust_remote_code, max_new_tokens=max_new_tokens,
            **config_kwargs
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Prepare WQLinear layers, replace nn.Linear
        cls._load_quantized_modules(cls, model, quant_config, quant_config.version)
        
        model.tie_weights()

        # Load the weights into the modules and distribute
        # them across the available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[cls.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )
        
        # Optionally fuse modules for faster inference
        if fuse_layers:
            cls.fuse_layers(model)

        return cls(model, model_type, is_quantized=is_quantized, config=config,
                   quant_config=quant_config, processor=None)

    def _load_config(self, model_path, model_filename, safetensors=True,
                     version="GEMM", trust_remote_code=True, max_new_tokens=4096,
                     **config_kwargs):
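        """Resolve the local weights path and load both configs.

        Downloads the checkpoint from the Hugging Face Hub when `model_path`
        is not a local directory, then loads the `AwqConfig` and the
        transformers config, setting `max_new_tokens` on the latter.
        """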
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # To add the generate support for Multi-modal models as well
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = max_new_tokens
        
        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
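        """Replace each quantizable nn.Linear with an empty WQLinear module.

        The WQLinear modules are allocated with `init_only=True`; the real
        quantized weights are filled in later by `load_checkpoint_and_dispatch`.
        """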
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unsupported version: {version}")

                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config.w_bit,
                    quant_config.q_group_size,
                    init_only=True,
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
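        """Wrap the block's activation in a ScaledActivation if the model defines one."""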
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)