import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
)
from accelerate.big_modeling import (
    init_empty_weights,
    infer_auto_device_map,
    load_checkpoint_and_dispatch,
)
from accelerate.utils import get_balanced_memory

class BaseAWQForCausalLM(nn.Module):
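    # Base class that wraps a Hugging Face causal LM and adds AWQ quantization,
    # saving and loading of quantized checkpoints. Typical flow (sketch; the
    # concrete subclass, paths and tokenizer are supplied by the caller):
    #   model = SomeAWQForCausalLM.from_pretrained(model_path, model_type)
    #   model.quantize(tokenizer, quant_config={...})
    #   model.save_quantized(out_dir)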
    def __init__(self, model, model_type, is_quantized, config, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]]="pileval", 
                       split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
        )
        quantizer.quantize()
        self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
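        # Writes the sharded quantized state dict plus the model and quantization configs to save_dir.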
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Placeholder module used to export an empty state dict
        class EmptyModule(nn.Module):
            def __init__(self): super().__init__()
            def forward(self, x): return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), 
            max_shard_size=shard_size, 
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors cannot serialize tensors that share memory, so clone into contiguous copies
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))
        
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
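        # Loads an unquantized Hugging Face model so it can later be quantized with quantize().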
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        # If not quantized, must load with AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs
        )

        model.eval()

        return self(model, model_type, is_quantized=False, config=config, quant_config=quant_config)

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='', 
                             max_new_tokens=None, torch_dtype=torch.float16, 
                             trust_remote_code=True, safetensors=True, is_quantized=True, 
                             fuse_layers=False, version='GEMM',
                             device_map="balanced", offload_folder=None,
                             **config_kwargs):
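        # Loads an already-quantized AWQ checkpoint and dispatches it across available devices.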
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, model_filename, safetensors, version, 
            trust_remote_code, max_new_tokens=max_new_tokens,
            **config_kwargs
        )
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config.version)
        
        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )
        
        # Optionally fuse supported modules for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        return self(model, model_type, is_quantized=is_quantized, config=config, quant_config=quant_config)

    def _load_config(self, model_path, model_filename, safetensors=True, 
                           version="GEMM", trust_remote_code=True, max_new_tokens=4096,
                           **config_kwargs):
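        # Resolves the local weights path (downloading from the Hub if needed) and
        # loads both the model config and the AWQ quantization config.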
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = max_new_tokens
        
        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unknown quantization version: {version}")
                
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config.w_bit,
                    quant_config.q_group_size,
                    True  # init_only: allocate an empty quantized layer; real weights are loaded from the checkpoint
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
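        # Wraps the block's activation function in ScaledActivation when the model
        # marks it as scalable.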
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)