import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}
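# For example, a config whose model_type is "llama" resolves to
# transformers.AutoModelForCausalLM:
#
#   target_cls = getattr(transformers, TRANSFORMERS_AUTO_MAPPING_DICT["llama"])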


class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self, model, model_type, is_quantized, config, quant_config, processor
    ):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
        duo_scaling=True,
        modules_to_not_convert=None,
        export_compatible=False,
    ):
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True
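    # A minimal usage sketch for quantize() (illustrative only: the checkpoint name and
    # quant_config values below are example assumptions, not defaults enforced here):
    #
    #   from awq import AutoAWQForCausalLM
    #   from transformers import AutoTokenizer
    #
    #   model = AutoAWQForCausalLM.from_pretrained("facebook/opt-125m")
    #   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    #   model.quantize(
    #       tokenizer,
    #       quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    #   )
    #   model.save_quantized("opt-125m-awq")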

    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized will
        overwrite existing weights if you use the same quant_path.

        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compatible weights
        model.pack(...)  # makes the model CUDA compatible
        model.save_quantized(...)  # produces CUDA compatible weights
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # Choose the weights file name based on the serialization format
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires tensors that own contiguous memory and do not share storage, so clone each one
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path,
        model_type,
        torch_dtype: torch.dtype = torch.float16,
        trust_remote_code=True,
        safetensors=False,
        device_map=None,
        **model_init_kwargs,
    ):
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # Not yet quantized: load the full-precision model with the mapped Auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path,
        model_type,
        model_filename="",
        max_new_tokens=None,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        safetensors=True,
        is_quantized=True,
        fuse_layers=False,
        use_exllama=False,
        use_exllama_v2=False,
        version="GEMM",
        device_map="balanced",
        offload_folder=None,
        **config_kwargs,
    ):
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            version,
            trust_remote_code,
            max_new_tokens=max_new_tokens,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse layers for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        if use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_new_tokens or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=is_quantized,
            config=config,
            quant_config=quant_config,
            processor=None,
        )
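    # A matching load-time sketch (illustrative only: the save path, original checkpoint
    # and generation settings are example assumptions):
    #
    #   from awq import AutoAWQForCausalLM
    #   from transformers import AutoTokenizer
    #
    #   model = AutoAWQForCausalLM.from_quantized("opt-125m-awq", fuse_layers=True)
    #   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    #   inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
    #   output_ids = model.generate(**inputs, max_new_tokens=32)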

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # Also propagate the generation length to multi-modal text configs
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(
                    config, self.max_new_tokens_key, 2048
                )
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True  # init_only
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)