import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
    PreTrainedTokenizer,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
    "qwen2": "AutoModelForCausalLM",
}
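# The mapping is resolved at load time; for example, a Llama checkpoint
# (config.model_type == "llama") is loaded roughly as follows:
#
#   target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["llama"]  # -> "AutoModelForCausalLM"
#   target_cls = getattr(transformers, target_cls_name)
#   model = target_cls.from_pretrained(...)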


class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self,
        model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
        model_type: Annotated[str, Doc("The model type, found in config.json.")],
        is_quantized: Annotated[
            bool, Doc("Indicates if the current model is quantized.")
        ],
        config: Annotated[PretrainedConfig, Doc("The config of the model.")],
        quant_config: Annotated[
            AwqConfig, Doc("The quantization config of the model.")
        ],
        processor: Annotated[
            AutoProcessor, Doc("An optional processor, e.g. for vision models.")
        ],
    ):
        """The base model for all AutoAWQ models."""
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
        """A utility function for moving the model to a device."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """A forward function that mimics the torch forward."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """
        A generate function that mimics the HF generate function.
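
        Example (a minimal sketch; assumes a CUDA device and a tokenizer loaded for the same model):

        ```python
        inputs = tokenizer("AWQ is", return_tensors="pt").to("cuda")
        output_ids = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
        ```
        """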
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer: Annotated[
            PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
        ] = None,
        quant_config: Annotated[
            Dict, Doc("The quantization config you want to use.")
        ] = {},
        calib_data: Annotated[
            Union[str, List[str]],
            Doc(
                "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
            ),
        ] = "pileval",
        split: Annotated[str, Doc("The split of calib_data.")] = "train",
        text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
        duo_scaling: Annotated[
            bool, Doc("Whether to scale using both w/x or just x.")
        ] = True,
        export_compatible: Annotated[
            bool,
            Doc(
                "If True, skips real quantization and only applies the AWQ scales, keeping the weights in FP16 so they can be exported to other formats."
            ),
        ] = False,
    ):
        """
        The main quantization function that you can use to quantize your model.

        Example:

        ```python
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        model_path = "..."
        model = AutoAWQForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
        model.quantize(tokenizer, quant_config)
        ```
        """
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        if hasattr(self, "modules_to_not_convert"):
            self.quant_config.modules_to_not_convert = self.modules_to_not_convert

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True

    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized will
        overwrite existing weights if you use the same quant_path.

        Example:

        ```python
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compat weights
        model.pack()  # makes the model CUDA compat
        model.save_quantized(...)  # produces CUDA compat weights
        ```
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(
        self,
        save_dir: Annotated[str, Doc("The directory to save your model to.")],
        safetensors: Annotated[
            bool, Doc("Whether to save the model as safetensors or torch files.")
        ] = True,
        shard_size: Annotated[
            str, Doc("The shard size for sharding large models into multiple chunks.")
        ] = "5GB",
    ):
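        """
        Save the quantized model weights and config files to `save_dir`.

        Example (a minimal sketch; the save path is a placeholder and `model`,
        `tokenizer`, and `quant_config` are assumed from a prior quantize() run):

        ```python
        model.quantize(tokenizer, quant_config=quant_config)
        model.save_quantized("path/to/quantized-model")
        tokenizer.save_pretrained("path/to/quantized-model")
        ```
        """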
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (5GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors must be in the same memory, so we duplicate and use contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = None,
        **model_init_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the model during initialization."
            ),
        ],
    ):
        """A method for initialization of pretrained models, usually in FP16."""
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # Not quantized yet, so load the FP16 model with the mapped transformers auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        model_filename: Annotated[
            str, Doc("Load a specific model's filename by specifying this argument.")
        ] = "",
        max_seq_len: Annotated[
            int,
            Doc(
                "The maximum cached sequence length of the model. Larger values may increase loading time and memory usage."
            ),
        ] = None,
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        fuse_layers: Annotated[
            bool,
            Doc(
                "Whether to use fused/optimized combination of layers for increased speed."
            ),
        ] = True,
        use_exllama: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
        ] = False,
        use_exllama_v2: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
        ] = False,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = "balanced",
        offload_folder: Annotated[
            str,
            Doc("The folder to offload the model to."),
        ] = None,
        **config_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the config during initialization."
            ),
        ],
    ):
        """
        A method for initialization of a quantized model, usually in INT4.
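
        Example (a minimal sketch; the model path is a placeholder for any AWQ checkpoint):

        ```python
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        quant_path = "path/to/quantized-awq-model"
        model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
        tokenizer = AutoTokenizer.from_pretrained(quant_path)
        ```
        """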
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            trust_remote_code,
            max_seq_len=max_seq_len,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        if quant_config.version == "marlin":
            model = marlin_post_init(model)
        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_seq_len or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=True,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        trust_remote_code=True,
        max_seq_len=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_seq_len is None and hasattr(self, "max_seq_len_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = getattr(config, self.max_seq_len_key, 2048)
            # Also set it on text_config so generation works for multi-modal models
            if hasattr(config, "text_config"):
                config.text_config.max_seq_len = getattr(
                    config, self.max_seq_len_key, 2048
                )
        else:
            max_seq_len = 2048 if max_seq_len is None else max_seq_len
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = max_seq_len

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert not (
            version == "gemv" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "gemm":
                    q_linear_module = WQLinear_GEMM
                elif version == "gemv":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)