# base.py
import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map


class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: AwqConfig = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]]="pileval", 
                       split="train", text_column="text"):
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
            self.quant_config.version, calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True
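
    # Usage sketch (illustrative call, not part of the original file): the quant_config
    # dict mirrors the fields consumed above (zero_point, q_group_size, w_bit, version):
    #
    #   model.quantize(
    #       tokenizer,
    #       quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    #       calib_data="pileval",
    #   )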
    
    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Placeholder module with an empty state dict so save_pretrained()
        # below writes only the config files, not the full weights
        class EmptyModule(nn.Module):
            def __init__(self): super().__init__()
            def forward(self, x): return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # Pick the weights filename to match the chosen serialization format
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that do not share memory,
                # so clone each tensor before saving
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))
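
    # Usage sketch (illustrative path): after quantize(), this writes the sharded
    # weights plus config.json and the quantization config:
    #
    #   model.save_quantized("./model-awq", safetensors=True, shard_size="10GB")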
        
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        if device_map is None:
            with init_empty_weights():
                model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

            # Get device map
            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type], 
                dtype=torch_dtype
            )
            del model

        # If not quantized, must load with AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            **model_init_kwargs
        )

        model.eval()

        return self(model, model_type, is_quantized=False, quant_config=quant_config)
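
    # Usage sketch (SomeAWQForCausalLM is a hypothetical concrete subclass; this base
    # class relies on subclass attributes such as layer_type):
    #
    #   model = SomeAWQForCausalLM.from_pretrained("org/model", "some_model", safetensors=True)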

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='', 
                             max_new_tokens=None, torch_dtype=torch.float16, 
                             trust_remote_code=True, safetensors=True, is_quantized=True, 
                             fuse_layers=False, version='GEMM',
                             max_memory=None, offload_folder=None,
                             **config_kwargs):
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, model_filename, safetensors, version, 
            trust_remote_code, max_new_tokens=max_new_tokens,
            **config_kwargs
        )
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config.version)
        
        model.tie_weights()

        # Get device map
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            max_memory=max_memory,
            dtype=torch_dtype
        )

        # Load checkpoint
        load_checkpoint_in_model(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            offload_folder=offload_folder,
            dtype=torch_dtype
        )
        
        # Optionally fuse layers (implemented by subclasses that override fuse_layers)
        if fuse_layers:
            self.fuse_layers(model)

        # Offloading dispatch
        from accelerate import dispatch_model
        model = dispatch_model(
            model,
            device_map=device_map,
            offload_dir=offload_folder
        )
        

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
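
    # Usage sketch (hypothetical subclass and path): load a previously saved AWQ
    # checkpoint and optionally fuse layers for inference:
    #
    #   model = SomeAWQForCausalLM.from_quantized("./model-awq", "some_model", fuse_layers=True)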

    def _load_config(self, model_path, model_filename, safetensors=True, 
                           version="GEMM", trust_remote_code=True, max_new_tokens=4096,
                           **config_kwargs):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = max_new_tokens
        
        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unsupported version: {version}. Expected 'GEMM' or 'GEMV'.")
                
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config.w_bit,
                    quant_config.q_group_size,
                    True  # initialize empty quantized layers; real weights are loaded from the checkpoint afterwards
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
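
# Minimal sketch of the hooks a concrete subclass is expected to provide. The names
# below are the attributes this base class accesses (layer_type, max_new_tokens_key,
# get_model_layers, get_act_for_scaling, fuse_layers); the example bodies are assumptions:
#
#   class SomeAWQForCausalLM(BaseAWQForCausalLM):
#       layer_type = "DecoderLayer"                      # used as no_split_module_classes entry
#       max_new_tokens_key = "max_position_embeddings"   # read in _load_config
#
#       @staticmethod
#       def get_model_layers(model):
#           # return the list of transformer blocks to quantize
#           return model.model.layers
#
#       @staticmethod
#       def get_act_for_scaling(module):
#           # return is_scalable (+ scale_name, scale_layer, scale_shape when scalable)
#           return dict(is_scalable=False)
#
#       @staticmethod
#       def fuse_layers(model):
#           # optional: replace modules with fused implementations
#           pass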