import os
import gc
import json
import time
import torch
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file

from huggingface_hub import snapshot_download
import transformers
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}
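# For example, config.model_type == "mistral" resolves to transformers.AutoModelForCausalLM,
# while "llava" resolves to transformers.AutoModelForVision2Seq.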


class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self, model, model_type, is_quantized, config, quant_config, processor
    ):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
        duo_scaling=True,
        modules_to_not_convert=None,
        export_compatible=False,
    ):
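        """
        Quantize the wrapped model in place using the AWQ algorithm.

        Args:
            tokenizer: tokenizer used to prepare the calibration samples.
            quant_config: dict parsed into an AwqConfig (keys such as w_bit,
                q_group_size, zero_point and version).
            calib_data: dataset name or list of calibration texts (default "pileval").
            split: dataset split to draw calibration samples from.
            text_column: column of the calibration dataset containing the text.
            duo_scaling: scale-search option forwarded to AwqQuantizer.
            modules_to_not_convert: linear layers to leave unquantized.
            export_compatible: if True, compute and apply AWQ scales but defer packing
                the weights into their low-bit form; call pack() afterwards (see pack()).

        Example (illustrative; assumes the package's AutoAWQForCausalLM entry point
        and a loaded tokenizer):

            model = AutoAWQForCausalLM.from_pretrained("path/to/model")
            model.quantize(
                tokenizer,
                quant_config={"w_bit": 4, "q_group_size": 128, "zero_point": True, "version": "GEMM"},
            )
            model.save_quantized("path/to/output")
        """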
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True

    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized will
        overwrite existing weights if you use the same quant_path.
127

128
129
130
131
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
132
        )
133
134
135
136
137
        model.save_quantized(...)  # produces GGUF/other compat weights
        model.pack(...) # makes the model CUDA compat
        model.save_quantized(...)  # produces CUDA compat weights
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
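        """Fuse supported modules for faster inference; a no-op unless overridden by a subclass."""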
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
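        """
        Save the model weights, the model config (including the quantization config),
        and the AWQ quant config to save_dir, sharding the state dict into files of at
        most shard_size.
        """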
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # Choose the weights file name according to the serialization format
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # Shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory, so clone each tensor
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # Save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path,
        model_type,
        torch_dtype: torch.dtype = torch.float16,
        trust_remote_code=True,
        safetensors=False,
        device_map=None,
        **model_init_kwargs,
    ):
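        """
        Load an un-quantized Hugging Face model and wrap it for AWQ quantization.

        model_path may be a local directory or a hub repo id (downloaded via
        snapshot_download). The transformers Auto class is resolved from
        TRANSFORMERS_AUTO_MAPPING_DICT based on config.model_type, and any extra
        model_init_kwargs are forwarded to its from_pretrained().
        """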
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # Load the un-quantized model with the transformers Auto class resolved above
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path,
        model_type,
        model_filename="",
        max_new_tokens=None,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        safetensors=True,
        is_quantized=True,
        fuse_layers=False,
        use_exllama=False,
        use_exllama_v2=False,
        version="GEMM",
        device_map="balanced",
        offload_folder=None,
        **config_kwargs,
    ):
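        """
        Load an AWQ-quantized checkpoint for inference.

        The model is first built on the meta device, its nn.Linear layers are swapped
        for the matching WQLinear modules (GEMM, GEMV, or the Exllama kernels), and the
        quantized weights are then loaded and dispatched across devices according to
        device_map via load_checkpoint_and_dispatch.
        """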
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            version,
            trust_remote_code,
            max_new_tokens=max_new_tokens,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse supported layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        if use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch space based on max_input_len and
            # max_batch_size, which are hardcoded for now but could be exposed as arguments
            model = exllamav2_post_init(
                model,
                max_input_len=max_new_tokens or 2048,  # fall back when max_new_tokens is None
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=is_quantized,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
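        """
        Resolve the local weights path (downloading from the hub if necessary) and load
        both the transformers model config and the AWQ quantization config.
        """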
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # Multi-modal models keep generation settings on the nested text_config
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(
                    config, self.max_new_tokens_key, 2048
                )
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
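        """
        Replace every quantizable nn.Linear in each model block with the WQLinear
        module matching the requested kernel (Exllama, ExllamaV2, GEMM, or GEMV).
        """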
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Drop the linear layers listed in modules_to_not_convert
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
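        """Wrap the block's scalable activation in a ScaledActivation with unit scales."""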
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # build a unit scale tensor on the layer's dtype and device
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # wrap the activation so its outputs can be rescaled
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)