import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_Exllama,
    WQLinear_ExllamaV2,
    WQLinear_GEMVFast,
    marlin_post_init,
    exllama_post_init,
    exllamav2_post_init,
)
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
    PreTrainedTokenizer,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
    "qwen2": "AutoModelForCausalLM",
    "gemma": "AutoModelForCausalLM",
}
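
# For example, a config whose `model_type` is "llama" resolves to
# `transformers.AutoModelForCausalLM`; `from_pretrained`/`from_quantized` below
# perform this lookup via `getattr` (a minimal sketch of the same resolution):
#
#     target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["llama"]  # "AutoModelForCausalLM"
#     target_cls = getattr(transformers, target_cls_name)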

class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self,
        model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
        model_type: Annotated[str, Doc("The model type, found in config.json.")],
        is_quantized: Annotated[
            bool, Doc("Indicates if the current model is quantized.")
        ],
        config: Annotated[PretrainedConfig, Doc("The config of the model.")],
        quant_config: Annotated[
            AwqConfig, Doc("The quantization config of the model.")
        ],
        processor: Annotated[
            AutoProcessor, Doc("An optional processor, e.g. for vision models.")
        ],
    ):
        """The base model for all AutoAWQ models."""
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
        """A utility function for moving the model to a device."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """A forward function that mimics the torch forward."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """A generate function that mimics the HF generate function."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer: Annotated[
            PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
        ] = None,
        quant_config: Annotated[
            Dict, Doc("The quantization config you want to use.")
        ] = {},
        calib_data: Annotated[
            Union[str, List[str]],
            Doc(
                "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
            ),
        ] = "pileval",
        split: Annotated[str, Doc("The split of calib_data.")] = "train",
        text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
        duo_scaling: Annotated[
            bool, Doc("Whether to scale using both w/x or just x.")
        ] = True,
        export_compatible: Annotated[
            bool,
            Doc(
                "This argument avoids real quantization by only applying the scales without quantizing down to FP16."
            ),
        ] = False,
    ):
        """
        The main quantization function that you can use to quantize your model.

        Example:

        ```python
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        model_path = "..."
        model = AutoAWQForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
        model.quantize(tokenizer, quant_config)
        ```
        """
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        if hasattr(self, "modules_to_not_convert"):
            self.quant_config.modules_to_not_convert = self.modules_to_not_convert

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True
    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized will
        overwrite existing weights if you use the same quant_path.
        Example:

        ```python
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compatible weights
        model.pack()  # makes the model CUDA compatible
        model.save_quantized(...)  # produces CUDA compatible weights
        ```
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(
        self,
        save_dir: Annotated[str, Doc("The directory to save your model to.")],
        safetensors: Annotated[
            bool, Doc("Whether to save the model as safetensors or torch files.")
        ] = True,
        shard_size: Annotated[
            str, Doc("The shard size for sharding large models into multiple chunks.")
        ] = "5GB",
    ):
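        # A minimal usage sketch (the output directory is a placeholder; assumes the
        # model was quantized first with `model.quantize(...)`):
        #
        #     model.save_quantized("./awq-model")
        #     tokenizer.save_pretrained("./awq-model")  # optional, keeps the tokenizer alongside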
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (5GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )
        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors must be in the same memory, so we duplicate and use contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))
        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))
    @classmethod
    def from_pretrained(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = None,
        **model_init_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the model during initialization."
            ),
        ],
    ):
        """A method for initialization of pretrained models, usually in FP16."""
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )
        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # If not quantized, load with the matching transformers auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )
    @classmethod
    def from_quantized(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        model_filename: Annotated[
            str, Doc("Load a specific model's filename by specifying this argument.")
        ] = "",
        max_seq_len: Annotated[
            int,
            Doc(
                "The maximum sequence cached sequence length of the model. Larger values may increase loading time and memory usage."
            ),
        ] = None,
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        fuse_layers: Annotated[
            bool,
            Doc(
                "Whether to use fused/optimized combination of layers for increased speed."
            ),
        ] = True,
        use_exllama: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
        ] = False,
        use_exllama_v2: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
        ] = False,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = "balanced",
        offload_folder: Annotated[
            str,
            Doc("The folder ot offload the model to."),
        ] = None,
        **config_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the config during initialization."
            ),
        ],
    ):
        """A method for initialization of a quantized model, usually in INT4."""
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            trust_remote_code,
            max_seq_len=max_seq_len,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)
        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)
        if quant_config.version == "marlin":
            model = marlin_post_init(model)

        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_seq_len or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=True,
            config=config,
            quant_config=quant_config,
            processor=None,
        )
    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        trust_remote_code=True,
        max_seq_len=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_seq_len is None and hasattr(self, "max_seq_len_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = getattr(config, self.max_seq_len_key, 2048)
            # Set max_seq_len on the text config as well, so multi-modal models support generate
            if hasattr(config, "text_config"):
                config.text_config.max_seq_len = getattr(
                    config, self.max_seq_len_key, 2048
                )
        else:
            max_seq_len = 2048 if max_seq_len is None else max_seq_len
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = max_seq_len
        return model_weights_path, config, quant_config
    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert not (
            version == "gemv" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Exclude the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "gemm":
                    q_linear_module = WQLinear_GEMM
                elif version == "gemv":
                    q_linear_module = WQLinear_GEMV
                elif version == "gemv_fast":
                    q_linear_module = WQLinear_GEMVFast
                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            torch.cuda.empty_cache()
            gc.collect()
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)
        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)