"megatron/vscode:/vscode.git/clone" did not exist on "bf1da27ea094d027e288645a6ba15bc76e5c3fba"
base.py 10.1 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import gc
3
import json
Casper Hansen's avatar
Casper Hansen committed
4
5
import torch
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
6
from tqdm import tqdm
Casper's avatar
Casper committed
7
from typing import List, Union
8
from safetensors.torch import save_file
Casper's avatar
Casper committed
9
from awq.models._config import AwqConfig
10
from awq.modules.act import ScaledActivation
11
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
12
from awq.quantize.quantizer import AwqQuantizer
13
from transformers.modeling_utils import shard_checkpoint
14
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
Casper Hansen's avatar
Casper Hansen committed
15
from awq.utils.module import get_named_linears, set_op_by_name
16
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
Casper Hansen's avatar
Casper Hansen committed
17
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
18

Casper's avatar
Casper committed
19

20
class BaseAWQForCausalLM(nn.Module):
21
    def __init__(self, model, model_type, is_quantized, quant_config):
22
        super().__init__()
23
24
25
26
        self.model:PreTrainedModel = model
        self.model_type:str = model_type
        self.is_quantized:bool = is_quantized
        self.search_result = None
Casper's avatar
Casper committed
27
        self.quant_config: AwqConfig = quant_config
28
29
30
31
32
33
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
Casper Hansen's avatar
Casper Hansen committed
34
35
36
37
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)
38

Casper Hansen's avatar
Casper Hansen committed
39
    @torch.no_grad()
Casper Hansen's avatar
Casper Hansen committed
40
41
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]]="pileval", 
Casper Hansen's avatar
Casper Hansen committed
42
                       split="train", text_column="text"):
Casper's avatar
Casper committed
43
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
44

Casper Hansen's avatar
Casper Hansen committed
45
        quantizer = AwqQuantizer(
Casper's avatar
Casper committed
46
47
            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
            self.quant_config.version, calib_data, split, text_column
Casper Hansen's avatar
Casper Hansen committed
48
49
50
        )
        quantizer.quantize()
        self.is_quantized = True
Casper Hansen's avatar
Casper Hansen committed
51
    
qwopqwop200's avatar
qwopqwop200 committed
52
    @staticmethod
Casper's avatar
Casper committed
53
    def fuse_layers(model):
qwopqwop200's avatar
qwopqwop200 committed
54
        pass
Casper's avatar
Casper committed
55

56
    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        """Write the config files and the (sharded) weights of the wrapped model.

        Args:
            save_dir: Target directory; a single trailing ``/`` is dropped.
            safetensors: If true, shards are written with ``safetensors`` under
                the base name ``model.safetensors``; otherwise with
                ``torch.save`` under ``pytorch_model.bin``.
            shard_size: Passed to ``shard_checkpoint`` as ``max_shard_size``.
        """
        # endswith() also copes with an empty string, where indexing [-1] would raise.
        save_dir = save_dir[:-1] if save_dir.endswith('/') else save_dir

        # Parameter-free module whose state dict is empty.
        class EmptyModule(nn.Module):
            def __init__(self): super(EmptyModule, self).__init__()
            def forward(self, x): return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Remove the placeholder weights file written by save_pretrained. Which
        # name it uses depends on the serialization format transformers chose,
        # so check both candidates and tolerate either being absent.
        for placeholder in ('pytorch_model.bin', 'model.safetensors'):
            placeholder_path = os.path.join(save_dir, placeholder)
            if os.path.exists(placeholder_path):
                os.remove(placeholder_path)

        # Base file name handed to shard_checkpoint and reused for the index file
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(),
            max_shard_size=shard_size,
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            shard_path = os.path.join(save_dir, shard_file)
            if safetensors:
                # safetensors must be in the same memory, so we duplicate and use contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, shard_path, metadata={"format": "pt"})
            else:
                torch.save(shard, shard_path)

        # save shard index
        if index is not None:
            index_path = os.path.join(save_dir, f'{model_name}.index.json')
            with open(index_path, 'w') as file:
                file.write(json.dumps(index, indent=4))
        
95
        
96
97
    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
        """Load a not-yet-quantized model and wrap it in this class.

        Args:
            model_path: Passed to ``_load_config`` to resolve the weights path.
            model_type: Stored on the returned wrapper.
            torch_dtype: Passed to ``AutoModelForCausalLM.from_pretrained``.
            trust_remote_code: Passed to the config and model loaders.
            safetensors: Passed as ``use_safetensors`` to the model loader.
            device_map: When given, passed as ``device_map`` to
                ``AutoModelForCausalLM.from_pretrained``. ``None`` leaves
                placement to that loader's default.
            **model_init_kwargs: Extra keyword arguments for the model loader.

        Returns:
            A wrapper created with ``is_quantized=False`` around the loaded
            model, which has been switched to eval mode.
        """
        # Get weights path and quant config (the model config is not needed here)
        model_weights_path, _, quant_config = cls._load_config(
            cls, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        # Honor an explicitly requested device map. Previously a map was inferred
        # from a throwaway empty-weights model and then never used, so the
        # argument had no effect on loading.
        if device_map is not None:
            model_init_kwargs['device_map'] = device_map

        # If not quantized, must load with AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            **model_init_kwargs
        )

        model.eval()

        return cls(model, model_type, is_quantized=False, quant_config=quant_config)

130
    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename='',
                       max_new_tokens=None, torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=False, is_quantized=True,
                       fuse_layers=False, version='GEMM',
                       max_memory=None, offload_folder=None):
        """Build a wrapper around an already-quantized checkpoint.

        Args:
            model_path: Passed to ``_load_config``.
            model_type: Stored on the returned wrapper.
            model_filename: Passed to ``_load_config``.
            max_new_tokens: Passed to ``_load_config``.
            torch_dtype: Used for the empty model, the device map and the
                checkpoint load.
            trust_remote_code: Passed to the config and model constructors.
            safetensors: Passed to ``_load_config``.
            is_quantized: Stored on the returned wrapper.
            fuse_layers: When true, ``cls.fuse_layers`` is called on the loaded
                model before dispatch.
            version: Passed to ``_load_config``; the linear layer class itself
                is selected from ``quant_config.version``.
            max_memory: Passed to ``infer_auto_device_map``.
            offload_folder: Passed to the checkpoint loader and to
                ``dispatch_model``.

        Returns:
            A wrapper around the dispatched model.
        """
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = cls._load_config(
            cls,
            model_path,
            model_filename,
            safetensors,
            version,
            trust_remote_code,
            max_new_tokens=max_new_tokens,
        )

        # [STEP 3] Load model: instantiate it from the config under init_empty_weights
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        cls._load_quantized_modules(cls, model, quant_config, quant_config.version)

        model.tie_weights()

        # Get device map
        placement = infer_auto_device_map(
            model,
            no_split_module_classes=[cls.layer_type],
            max_memory=max_memory,
            dtype=torch_dtype,
        )

        # Load checkpoint
        load_checkpoint_in_model(
            model,
            checkpoint=model_weights_path,
            device_map=placement,
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optional fusing happens before the model is dispatched to devices
        if fuse_layers:
            cls.fuse_layers(model)

        # Offloading dispatch
        from accelerate import dispatch_model
        model = dispatch_model(model, device_map=placement, offload_dir=offload_folder)

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_config(self, model_path, model_filename, safetensors=False,
                     version="GEMM", trust_remote_code=True, max_new_tokens=4096):
        """Resolve the weights location and load the quantization and model configs.

        Args:
            model_path: A local directory, or an identifier handed to
                ``snapshot_download`` when it is not a directory.
            model_filename: Optional file name appended to ``model_path``.
            safetensors: Selects which weight file patterns are skipped when
                downloading.
            version: Not used in this method.
            trust_remote_code: Passed to ``AutoConfig.from_pretrained``.
            max_new_tokens: Value stored on the model config. When ``None``,
                the config attribute named by ``self.max_new_tokens_key`` is
                used if that key exists, else 2048.

        Returns:
            Tuple ``(model_weights_path, config, quant_config)``.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            skipped = ["*msgpack*", "*h5*"]
            # Skip whichever weight format was not requested
            skipped += ["*.pt*", "*.bin*"] if safetensors else ["*.safetensors*"]
            model_path = snapshot_download(model_path, ignore_patterns=skipped)

        model_weights_path = model_path
        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        use_config_key = max_new_tokens is None and hasattr(self, 'max_new_tokens_key')
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if use_config_key:
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        return model_weights_path, config, quant_config
Casper's avatar
Casper committed
214

Casper Hansen's avatar
Casper Hansen committed
215
    def _load_quantized_modules(self, model, quant_config, version):
216
        # Real quantization of weights
Casper's avatar
Casper committed
217
        assert quant_config.zero_point, "We only support zero_point quantization now."
218
219
        
        # Get blocks of model
220
        layers = self.get_model_layers(model)
221

222
223
        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]
224
225

            # Get every linear layer in a block
226
            named_linears = get_named_linears(layer)
227
228

            # Replace activation functions
229
            self._scale_activations(self, layer)
230

231
            # Replace nn.Linear with WQLinear
232
            for name, module in named_linears.items():
Casper Hansen's avatar
Casper Hansen committed
233
234
235
236
237
238
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
239
                    module,
Casper's avatar
Casper committed
240
241
                    quant_config.w_bit,
                    quant_config.q_group_size,
Casper Hansen's avatar
Casper Hansen committed
242
243
                    True
                )
244
245
246
247
248
249
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
250
    @staticmethod
251
    def _scale_activations(self, layer):
252
        scale_dict = self.get_act_for_scaling(layer)
253

254
255
256
        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())
257

258
259
                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
260

261
262
                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
Casper's avatar
Casper committed
263
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)