import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map

class BaseAWQForCausalLM(nn.Module):
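    """Base wrapper that adds AWQ quantization, saving, and loading on top of a Hugging Face causal LM."""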
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                       calib_data: Union[str, List[str]]="pileval",
                       split="train", text_column="text"):
        # Avoid mutating a shared default argument
        quant_config = {} if quant_config is None else quant_config
        self.quant_config = quant_config
        quant_config["version"] = quant_config.get("version", "GEMM")

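        # Deferred import, presumably to avoid a circular dependency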
        from awq.quantize.quantizer import AwqQuantizer

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model, quant_config):
        # No-op by default; subclasses that support fused modules override this
        pass

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
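        """Save the quantized weights (sharded if needed) and quant_config.json to save_dir."""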
        save_dir = save_dir.rstrip('/')

        # Dummy module used to write the model files with an empty state dict
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save model files with empty state dict
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Remove the empty state dict file that save_pretrained wrote
        os.remove(f'{save_dir}/pytorch_model.bin')

        # Pick the weights file name for the chosen serialization format
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory, so clone each shard
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False):
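        """Load an unquantized FP16 model, e.g. as the starting point for quantize()."""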
        return self.from_quantized(
            model_path, 
            model_type, 
            model_filename='', 
            max_new_tokens=None,
            device='balanced', 
            torch_dtype=torch_dtype, 
            trust_remote_code=trust_remote_code, 
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='', 
                             max_new_tokens=None, device='balanced', torch_dtype=torch.float16, 
                             trust_remote_code=True, safetensors=False, is_quantized=True, 
                             fuse_layers=False, version='GEMM'):
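        """
        Load an AWQ-quantized checkpoint: fetch it if model_path is a hub id,
        swap nn.Linear for WQLinear modules, then load and dispatch the weights.
        """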
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = os.path.join(model_path, model_filename)
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
            
            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        
        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])
        
        model.tie_weights()

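        # Spread the model across devices without splitting whole decoder
        # blocks (self.layer_type is provided by the model subclass)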
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            load_checkpoint_in_model(
                model,
                checkpoint=model_weights_path,
                device_map=device_map
            )
            
            model = simple_dispatch_model(model, device_map)
            
            if fuse_layers:
                self.fuse_layers(model, quant_config)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model
            
            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_weights_path, 
                device_map=device_map, 
                trust_remote_code=trust_remote_code, 
                offload_folder="offload", 
                offload_state_dict=True, 
                torch_dtype=torch_dtype, 
                use_safetensors=safetensors
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        # Replace each block's nn.Linear layers with packed WQLinear modules
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unsupported version: {version}")

                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True  # init-only: allocate empty quantized buffers; weights are loaded from the checkpoint afterwards
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
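        """Wrap the block's activation in ScaledActivation when the subclass marks it scalable."""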
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
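
# Minimal usage sketch. `LlamaAWQForCausalLM` and the paths are illustrative
# assumptions about a concrete subclass; this module only defines the base class:
#
#   model = LlamaAWQForCausalLM.from_pretrained("path/to/fp16-model", "llama")
#   model.quantize(tokenizer, quant_config={
#       "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM",
#   })
#   model.save_quantized("path/to/awq-model", safetensors=True)
#   model = LlamaAWQForCausalLM.from_quantized("path/to/awq-model", "llama")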