import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_Exllama,
    WQLinear_ExllamaV2,
    WQLinear_GEMVFast,
    marlin_post_init,
    exllama_post_init,
    exllamav2_post_init,
)
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
    PreTrainedTokenizer,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
    "qwen2": "AutoModelForCausalLM",
}
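
# The mapping is resolved at load time, e.g. for a llama checkpoint:
#   target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["llama"]  # -> "AutoModelForCausalLM"
#   target_cls = getattr(transformers, target_cls_name)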


class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self,
        model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
        model_type: Annotated[str, Doc("The model type, found in config.json.")],
        is_quantized: Annotated[
            bool, Doc("Indicates if the current model is quantized.")
        ],
        config: Annotated[PretrainedConfig, Doc("The config of the model.")],
        quant_config: Annotated[
            AwqConfig, Doc("The quantization config of the model.")
        ],
        processor: Annotated[
            AutoProcessor, Doc("An optional processor, e.g. for vision models.")
        ],
    ):
        """The base model for all AutoAWQ models."""
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
        """A utility function for moving the model to a device."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """A forward function that mimics the torch forward."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """A generate function that mimics the HF generate function."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer: Annotated[
            PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
        ] = None,
        quant_config: Annotated[
            Dict, Doc("The quantization config you want to use.")
        ] = {},
        calib_data: Annotated[
            Union[str, List[str]],
            Doc(
                "The calibration dataset. Either the name of a dataset on Huggingface or a list of preloaded examples."
            ),
        ] = "pileval",
        split: Annotated[str, Doc("The split of calib_data.")] = "train",
        text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
        duo_scaling: Annotated[
            bool, Doc("Whether to scale using both w/x or just x.")
        ] = True,
        export_compatible: Annotated[
            bool,
            Doc(
                "This argument avoids real quantization by only applying the scales, leaving the weights in FP16 rather than quantizing them to low-bit values."
            ),
        ] = False,
    ):
        """
        The main quantization function that you can use to quantize your model.

        Example:

        ```python
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        model_path = "..."
        model = AutoAWQForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
        model.quantize(tokenizer, quant_config)
        ```
        """
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        if hasattr(self, "modules_to_not_convert"):
            self.quant_config.modules_to_not_convert = self.modules_to_not_convert

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True

    @torch.no_grad()
    def pack(self):
        """
        A utility function for packing the model after quantizing with
        export_compatible=True. Note that save_quantized will overwrite existing
        weights if you use the same quant_path.

        Example:

        ```python
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compatible weights
        model.pack()  # makes the model CUDA compatible
        model.save_quantized(...)  # produces CUDA compatible weights
        ```
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
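        """A no-op hook; model-specific subclasses can override it to fuse layers."""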
        pass

    def save_quantized(
        self,
        save_dir: Annotated[str, Doc("The directory to save your model to.")],
        safetensors: Annotated[
            bool, Doc("Whether to save the model as safetensors or torch files.")
        ] = True,
        shard_size: Annotated[
            str, Doc("The shard size for sharding large models into multiple chunks.")
        ] = "5GB",
    ):
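        """
        Save the quantized model weights and config files to `save_dir`,
        sharding the state dict when it exceeds `shard_size`.

        Example (a minimal sketch; the path is a placeholder and `tokenizer`/
        `quant_config` are set up as in the `quantize()` docstring example):

        ```python
        model.quantize(tokenizer, quant_config)

        quant_path = "opt-125m-awq"
        model.save_quantized(quant_path)
        tokenizer.save_pretrained(quant_path)  # the tokenizer is not saved by save_quantized
        ```
        """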
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Dummy module with an empty state dict so save_pretrained writes configs without the real weights
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (5GB by default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory, so clone each tensor to a contiguous copy
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. Values other than float16 may not work."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed on to the model loading method from transformers."
            ),
        ] = None,
        **model_init_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the model during initialization."
            ),
        ],
    ):
        """A method for initialization of pretrained models, usually in FP16."""
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # If not quantized, load with the transformers class mapped for this model type
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        model_filename: Annotated[
            str, Doc("Load a specific model's filename by specifying this argument.")
        ] = "",
        max_seq_len: Annotated[
            int,
            Doc(
                "The maximum cached sequence length of the model. Larger values may increase loading time and memory usage."
            ),
        ] = None,
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. Values other than float16 may not work."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        fuse_layers: Annotated[
            bool,
            Doc(
                "Whether to use fused/optimized combination of layers for increased speed."
            ),
        ] = True,
        use_exllama: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
        ] = False,
        use_exllama_v2: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
        ] = False,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed on to the model loading method from transformers."
            ),
        ] = "balanced",
        offload_folder: Annotated[
            str,
            Doc("The folder to offload the model to."),
        ] = None,
        **config_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the config during initialization."
            ),
        ],
    ):
        """A method for initialization of a quantized model, usually in INT4."""
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            trust_remote_code,
            max_seq_len=max_seq_len,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        if quant_config.version == "marlin":
            model = marlin_post_init(model)

        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_seq_len or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=True,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        trust_remote_code=True,
        max_seq_len=4096,
        **config_kwargs,
    ):
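        """Resolve the local or downloaded weights path and load both the model
        config and the AWQ quantization config."""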
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_seq_len is None and hasattr(self, "max_seq_len_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = getattr(config, self.max_seq_len_key, 2048)
            # Also set it on the text config so multi-modal models support generate
            if hasattr(config, "text_config"):
                config.text_config.max_seq_len = getattr(
                    config, self.max_seq_len_key, 2048
                )
        else:
            max_seq_len = 2048 if max_seq_len is None else max_seq_len
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = max_seq_len

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
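        """Replace every quantizable nn.Linear in each block with the WQLinear
        variant selected by `version` and the ExLlama flags."""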
        # Real quantization of weights
        assert not (
            version == "gemv" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "gemm":
                    q_linear_module = WQLinear_GEMM
                elif version == "gemv":
                    q_linear_module = WQLinear_GEMV
                elif version == "gemv_fast":
                    q_linear_module = WQLinear_GEMVFast

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
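        """Wrap the layer's activation in ScaledActivation (scales initialized to
        ones) when the model declares it scalable via get_act_for_scaling."""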
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)