import os
import gc
import json

import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map

class BaseAWQForCausalLM(nn.Module):
    """Base wrapper around a Hugging Face causal LM; model-specific AWQ classes subclass this."""

    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)
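
    # Example (illustrative sketch; tokenizer and inputs are placeholders): generation
    # is simply forwarded to the wrapped Hugging Face model under inference_mode.
    #
    #   inputs = tokenizer("Hello, my name is", return_tensors="pt")
    #   output_ids = model.generate(inputs.input_ids.cuda(), max_new_tokens=64)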

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                       calib_data: Union[str, List[str]]="pileval",
                       split="train", text_column="text", loss_objective='mse'):
        # Avoid a shared mutable default argument; fall back to an empty config
        quant_config = quant_config if quant_config is not None else {}
        self.quant_config = quant_config
        # Default to the GEMM kernels when no version is specified
        quant_config["version"] = quant_config.get("version", "GEMM")

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column, loss_objective
        )
        quantizer.quantize()
        self.is_quantized = True
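
    # Example (illustrative sketch): quantize with a Hugging Face tokenizer and a
    # typical AWQ config; the values mirror the defaults used in from_quantized.
    #
    #   quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
    #   model.quantize(tokenizer, quant_config=quant_config, calib_data="pileval")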
    
    @staticmethod
    def fuse_layers(model, quant_config):
        # No-op by default; model-specific subclasses override this to fuse
        # modules for faster inference
        pass
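
    # Sketch of how a model-specific subclass might override fuse_layers
    # (hypothetical fuser helper; real implementations live in the subclasses):
    #
    #   @staticmethod
    #   def fuse_layers(model, quant_config):
    #       fuser = MyModelFuser(model, quant_config)
    #       fuser.fuse_attention()
    #       fuser.fuse_mlp()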

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Placeholder module so save_pretrained writes config/tokenizer files
        # without dumping any real weights
        class EmptyModule(nn.Module):
            def __init__(self): super().__init__()
            def forward(self, x): return x

        # Save model files with empty state dict
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Remove the empty state dict written above
        os.remove(f'{save_dir}/pytorch_model.bin')

        # Pick the weights file name to match the requested format
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # Shard the checkpoint into chunks (10GB by default)
        shards, index = shard_checkpoint(
            self.model.state_dict(),
            max_shard_size=shard_size,
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous, non-shared tensors, so clone each one
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # Save the shard index (only produced when the checkpoint was split)
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save quantization config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))
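
    # Example (illustrative sketch; output directory is a placeholder): persist the
    # quantized weights, quant_config.json and the model config files, then save
    # the tokenizer alongside them.
    #
    #   model.save_quantized("opt-125m-awq", safetensors=True)
    #   tokenizer.save_pretrained("opt-125m-awq")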
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        # Delegate to from_quantized with is_quantized=False so the plain FP16
        # weights are loaded with AutoModelForCausalLM
        return self.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )
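
    # Example (illustrative sketch; model path and type are placeholders): load an
    # FP16 model through a model-specific subclass before calling quantize().
    #
    #   model = AwqModelSubclass.from_pretrained("facebook/opt-125m", "opt")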

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='',
                             max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
                             trust_remote_code=True, safetensors=False, is_quantized=True,
                             fuse_layers=False, version='GEMM'):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())

            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model (on the meta device, without allocating real weights)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if the model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])

        model.tie_weights()

        # Note: the `device` argument is currently unused; placement is decided
        # by infer_auto_device_map below
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            load_checkpoint_in_model(
                model,
                checkpoint=model_weights_path,
                device_map=device_map
            )

            model = simple_dispatch_model(model, device_map)

            if fuse_layers:
                self.fuse_layers(model, quant_config)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model

            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_weights_path,
                device_map=device_map,
                trust_remote_code=trust_remote_code,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
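
    # Example (illustrative sketch; paths and names are placeholders): reload a saved
    # AWQ checkpoint, optionally fusing layers for faster inference.
    #
    #   model = AwqModelSubclass.from_quantized("opt-125m-awq", "opt", fuse_layers=True)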

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unknown version: {version}. Expected 'GEMM' or 'GEMV'.")

                # True -> initialize empty quantized buffers only; the real weights
                # are loaded from the checkpoint afterwards
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()
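
    # Illustrative note (hypothetical module names): after this pass, a block's
    # linears such as `layer.self_attn.q_proj` are WQLinear_GEMM / WQLinear_GEMV
    # modules holding packed low-bit weights plus per-group scales and zeros,
    # in place of the original nn.Linear.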
    
    @staticmethod
    def _scale_activations(self, layer):
        # Note: despite the name, `self` here is the AWQ model class passed in
        # explicitly by the caller (this is a @staticmethod)
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
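
    # Illustrative sketch of the dict get_act_for_scaling is expected to return
    # (keys taken from the code above; the concrete layer and shape are hypothetical):
    #
    #   {
    #       "is_scalable": True,
    #       "scale_name": "mlp.act",
    #       "scale_layer": layer.mlp.act,
    #       "scale_shape": layer.mlp.intermediate_size,
    #   }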