import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_Exllama,
    WQLinear_ExllamaV2,
    WQLinear_GEMVFast,
    marlin_post_init,
    exllama_post_init,
    exllamav2_post_init,
)
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
    PreTrainedTokenizer,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelFor*` classes from transformers,
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
    "qwen2": "AutoModelForCausalLM",
    "gemma": "AutoModelForCausalLM",
    "stablelm": "AutoModelForCausalLM",
    "starcoder2": "AutoModelForCausalLM",
}
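# For example, a checkpoint whose config.json has model_type == "llama" resolves
# to getattr(transformers, "AutoModelForCausalLM") when the model is loaded.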


class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self,
        model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
        model_type: Annotated[str, Doc("The model type, found in config.json.")],
        is_quantized: Annotated[
            bool, Doc("Indicates if the current model is quantized.")
        ],
        config: Annotated[PretrainedConfig, Doc("The config of the model.")],
        quant_config: Annotated[
            AwqConfig, Doc("The quantization config of the model.")
        ],
        processor: Annotated[
            AutoProcessor, Doc("An optional processor, e.g. for vision models.")
        ],
    ):
        """The base model for all AutoAWQ models."""
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
        """A utility function for moving the model to a device."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """A forward function that mimics the torch forward."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """A generate function that mimics the HF generate function."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer: Annotated[
            PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
        ] = None,
        quant_config: Annotated[
            Dict, Doc("The quantization config you want to use.")
        ] = {},
        calib_data: Annotated[
            Union[str, List[str]],
            Doc(
                "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
            ),
        ] = "pileval",
        split: Annotated[str, Doc("The split of calib_data.")] = "train",
        text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
        duo_scaling: Annotated[
            bool, Doc("Whether to scale using both w/x or just x.")
        ] = True,
        export_compatible: Annotated[
            bool,
            Doc(
                "This argument avoids real quantization by only applying the scales, leaving the weights in FP16."
            ),
        ] = False,
        apply_clip: Annotated[
            bool,
            Doc(
                "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False."
            ),
        ] = True,
    ):
        """
        The main quantization function that you can use to quantize your model.

        Example:

        ```python
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        model_path = "..."
        model = AutoAWQForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
        model.quantize(tokenizer, quant_config)
        ```
        """
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        if hasattr(self, "modules_to_not_convert"):
            self.quant_config.modules_to_not_convert = self.modules_to_not_convert

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
            export_compatible=export_compatible,
            apply_clip=apply_clip,
        )
        self.quantizer.quantize()

        self.is_quantized = True

    @torch.no_grad()
    def pack(self):
        """
        A utility function for packing the model after quantizing with
        export_compatible=True, as in the following scenario. Note that save_quantized
        will overwrite existing weights if you use the same quant_path.

        Example:

        ```python
        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compat weights
        model.pack()  # makes the model CUDA compat
        model.save_quantized(...)  # produces CUDA compat weights
        ```
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(
        self,
        save_dir: Annotated[str, Doc("The directory to save your model to.")],
        safetensors: Annotated[
            bool, Doc("Whether to save the model as safetensors or torch files.")
        ] = True,
        shard_size: Annotated[
            str, Doc("The shard size for sharding large models into multiple chunks.")
        ] = "5GB",
    ):
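        """Save the quantized model, its config, and the quantization config to a directory.

        Example (the path is a placeholder and assumes the model has been quantized):

        ```python
        quant_path = "..."
        model.save_quantized(quant_path)
        tokenizer.save_pretrained(quant_path)
        ```
        """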
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (5GB default, set by shard_size)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory, so clone each tensor
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = None,
        download_kwargs: Annotated[
            Dict, Doc("Used to configure the model download, passed to snapshot_download."),
        ] = None,
        **model_init_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the model during initialization."
            ),
        ],
    ):
        """A method for initialization of pretrained models, usually in FP16."""
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors,
            trust_remote_code=trust_remote_code,
            download_kwargs=download_kwargs
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # If not quantized, load with the mapped transformers Auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
        model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
        model_filename: Annotated[
            str, Doc("Load a specific model's filename by specifying this argument.")
        ] = "",
        max_seq_len: Annotated[
            int,
            Doc(
                "The maximum sequence cached sequence length of the model. Larger values may increase loading time and memory usage."
            ),
        ] = None,
        torch_dtype: Annotated[
            torch.dtype,
            Doc(
                "The dtype to load the model as. May not work with other values than float16."
            ),
        ] = torch.float16,
        trust_remote_code: Annotated[
            bool,
            Doc(
                "Useful for Huggingface repositories that have not been integrated into transformers yet."
            ),
        ] = True,
        safetensors: Annotated[
            bool, Doc("Whether to download/load safetensors instead of torch weights.")
        ] = True,
        fuse_layers: Annotated[
            bool,
            Doc(
                "Whether to use fused/optimized combination of layers for increased speed."
            ),
        ] = True,
        use_exllama: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
        ] = False,
        use_exllama_v2: Annotated[
            bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
        ] = False,
        device_map: Annotated[
            Union[str, Dict],
            Doc(
                "A device map that will be passed onto the model loading method from transformers."
            ),
        ] = "balanced",
        max_memory: Annotated[
            Dict[Union[int, str], Union[int, str]],
            Doc(
                'A dictionary mapping device identifiers to their maximum memory, passed to the model loading method from transformers. For example: {0: "4GB", 1: "10GB"}.'
            ),
        ] = None,
        offload_folder: Annotated[
            str,
            Doc("The folder to offload the model to."),
        ] = None,
        download_kwargs: Annotated[
            Dict, Doc("Used to configure the model download, passed to snapshot_download."),
        ] = None,
        **config_kwargs: Annotated[
            Dict,
            Doc(
                "Additional kwargs that are passed to the config during initialization."
            ),
        ],
    ):
        """A method for initialization of a quantized model, usually in INT4."""
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            trust_remote_code,
            max_seq_len=max_seq_len,
            download_kwargs=download_kwargs,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            max_memory=max_memory,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse layers for increased speed
        if fuse_layers:
            self.fuse_layers(model)

        if quant_config.version == "marlin":
            model = marlin_post_init(model)

        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_seq_len or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=True,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        trust_remote_code=True,
        max_seq_len=4096,
        download_kwargs=None,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")
            
            if download_kwargs is None:
                download_kwargs = {}
            
            if "ignore_patterns" in download_kwargs:
                download_kwargs_ignore_patterns = download_kwargs.pop("ignore_patterns")

                if isinstance(download_kwargs_ignore_patterns, str):
                    ignore_patterns.append(download_kwargs_ignore_patterns)
                elif isinstance(download_kwargs_ignore_patterns, list):
                    ignore_patterns.extend(download_kwargs_ignore_patterns)

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns, **download_kwargs)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_seq_len is None and hasattr(self, "max_seq_len_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = getattr(config, self.max_seq_len_key, 2048)
            # Also set max_seq_len on text_config so generate works for multi-modal models
            if hasattr(config, "text_config"):
                config.text_config.max_seq_len = getattr(
                    config, self.max_seq_len_key, 2048
                )
        else:
            max_seq_len = 2048 if max_seq_len is None else max_seq_len
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_seq_len = max_seq_len

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert not (
            version == "gemv" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "gemm":
                    q_linear_module = WQLinear_GEMM
                elif version == "gemv":
                    q_linear_module = WQLinear_GEMV
                elif version == "gemv_fast":
                    q_linear_module = WQLinear_GEMVFast

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)