# base.py
import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map

class BaseAWQForCausalLM(nn.Module):
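    """Base class wrapping a Hugging Face causal LM with AWQ quantization,
    saving, and loading logic. Model-specific subclasses are expected to
    provide `layer_type`, `max_new_tokens_key`, `get_model_layers`, and
    `get_act_for_scaling`."""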
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: Dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                       calib_data: Union[str, List[str]]="pileval", 
                       split="train", text_column="text"):
        # Avoid a shared mutable default argument; fall back to an empty dict
        quant_config = {} if quant_config is None else quant_config
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model, quant_config):
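        # No-op in the base class; model-specific subclasses may override this
        # to fuse modules after the quantized weights are loaded.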
        pass

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
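        """Save sharded quantized weights plus `quant_config.json` to `save_dir`."""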
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Dummy module whose empty state dict lets save_pretrained write the config files without duplicating weights
        class EmptyModule(nn.Module):
            def __init__(self): super(EmptyModule, self).__init__()
            def forward(self, x): return x

        # Save model files with empty state dict
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Remove empty state dict
        os.remove(f'{save_dir}/pytorch_model.bin')

        # model_name has no extension, add it when saving state_dict
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors cannot serialize shared or non-contiguous tensors, so clone each one into its own contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
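        """Load an unquantized model (FP16 by default) so it can be quantized afterwards."""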
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        if device_map is None:
            with init_empty_weights():
                model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

            # Get device map
            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type], 
                dtype=torch_dtype
            )
            del model

        # If not quantized, must load with AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            **model_init_kwargs
        )

        model.eval()

        return self(model, model_type, is_quantized=False, quant_config=quant_config)

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='', 
                             max_new_tokens=None, torch_dtype=torch.float16, 
                             trust_remote_code=True, safetensors=False, is_quantized=True, 
                             fuse_layers=False, version='GEMM',
                             max_memory=None, offload_folder=None):
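        """Load an already-quantized AWQ checkpoint and dispatch it to devices.

        A minimal usage sketch (model path and subclass are illustrative):

            model = LlamaAWQForCausalLM.from_quantized("path/to/awq-model", "llama")
            outputs = model.generate(input_ids, max_new_tokens=64)
        """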
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, model_filename, safetensors, version, 
            trust_remote_code, max_new_tokens=max_new_tokens
        )
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config["version"])
        
        model.tie_weights()

        # Get device map
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            max_memory=max_memory,
            dtype=torch_dtype
        )

        # Load checkpoint
        load_checkpoint_in_model(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            offload_folder=offload_folder,
            dtype=torch_dtype
        )
        
        # Dispatch to devices
        if max_memory is None:
            # VRAM only
            model = simple_dispatch_model(model, device_map)

            if fuse_layers:
                self.fuse_layers(model, quant_config)
        else:
            if fuse_layers:
                self.fuse_layers(model, quant_config)

            # Offloading dispatch
            from accelerate import dispatch_model
            model = dispatch_model(
                model,
                device_map=device_map,
                # offload_buffers=offload_folder is not None,
                offload_dir=offload_folder
            )

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_config(self, model_path, model_filename, safetensors=False, 
                           version="GEMM", trust_remote_code=True, max_new_tokens=4096):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())
            
            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        
        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
        
        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
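        """Replace every nn.Linear inside each transformer block with a packed WQLinear module."""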
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unknown quantization version: {version}")
                
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True # init_only: allocate buffers now, real weights come from the checkpoint
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)