import os
import gc
import json
import torch
import transformers
import torch.nn as nn

from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.quantize.quantizer import AwqQuantizer

# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "baichuan": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}


class BaseAWQForCausalLM(nn.Module):
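    """
    Thin wrapper around a Hugging Face transformers model that adds AWQ
    quantization, packing, saving, and loading. Model-specific subclasses
    supply the layer hooks used below (e.g. get_model_layers,
    get_act_for_scaling, layer_type).

    Typical flow (a minimal sketch, assuming the AutoAWQForCausalLM entry
    point that resolves the concrete subclass from the model config):

        model = AutoAWQForCausalLM.from_pretrained(model_path)
        model.quantize(tokenizer, quant_config=quant_config)
        model.save_quantized(quant_path)
    """
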
    def __init__(
        self, model, model_type, is_quantized, config, quant_config, processor
    ):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
        duo_scaling=True,
        modules_to_not_convert=None,
        export_compatible=False,
    ):
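        """
        Quantize the wrapped model with AWQ using the given calibration data.

        A minimal usage sketch (the config keys mirror the AwqConfig fields
        consumed below; adjust them for your model):

            quant_config = {
                "zero_point": True, "q_group_size": 128,
                "w_bit": 4, "version": "GEMM",
            }
            model.quantize(tokenizer, quant_config=quant_config)
        """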
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

        self.is_quantized = True

    @torch.no_grad()
    def pack(self):
        """
        A utility function for the following scenario. Note that save_quantized
        will overwrite existing weights if you use the same quant_path.

        model.quantize(
            tokenizer,
            quant_config=quant_config,
            export_compatible=True
        )
        model.save_quantized(...)  # produces GGUF/other compat weights
        model.pack()  # makes the model CUDA compat
        model.save_quantized(...)  # produces CUDA compat weights
        """
        self.quantizer.pack()

    @staticmethod
    def fuse_layers(model):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
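        """
        Save the quantized weights, configs, and processor (if any) to
        save_dir, sharding the state dict into chunks of at most shard_size.
        """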
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # Standard weights filename for the chosen serialization format
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that do not share memory, so clone each one
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path,
        model_type,
        torch_dtype: torch.dtype = torch.float16,
        trust_remote_code=True,
        safetensors=False,
        device_map=None,
        **model_init_kwargs,
    ):
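        """
        Load an unquantized model from model_path, ready for quantization,
        via the transformers Auto class mapped in
        TRANSFORMERS_AUTO_MAPPING_DICT.
        """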
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # The model is not yet quantized, so load it through the mapped Auto class
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path,
        model_type,
        model_filename="",
        max_new_tokens=None,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        safetensors=True,
        is_quantized=True,
        fuse_layers=False,
        use_exllama=False,
        use_exllama_v2=False,
        version="GEMM",
        device_map="balanced",
        offload_folder=None,
        **config_kwargs,
    ):
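        """
        Load an AWQ-quantized model: build the skeleton with empty weights,
        swap nn.Linear for WQLinear modules, then load and dispatch the real
        weights across the available devices.
        """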
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            version,
            trust_remote_code,
            max_new_tokens=max_new_tokens,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

        # loads the weights into modules and distributes
        # across available devices automatically
        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
            dtype=torch_dtype,
        )

        # Optionally fuse supported submodules for faster inference
        if fuse_layers:
            self.fuse_layers(model)

        if quant_config.version == "Marlin":
            model = marlin_post_init(model)

        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
            model = exllamav2_post_init(
                model,
                max_input_len=max_new_tokens or 2048,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
            )

        return self(
            model,
            model_type,
            is_quantized=is_quantized,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
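        """
        Resolve the local weights path (downloading from the Hub if needed)
        and load both the transformers config and the AWQ quant config.
        """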
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # To add the generate support for Multi-modal models as well
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(
                    config, self.max_new_tokens_key, 2048
                )
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
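        """
        Replace each block's nn.Linear layers with empty WQLinear modules so
        that load_checkpoint_and_dispatch can stream the quantized weights
        into them.
        """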
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Drop the linear layers that are excluded from quantization
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "Marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                # init_only=True: allocate empty quantized buffers; the real
                # weights are loaded later by load_checkpoint_and_dispatch
                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
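        """
        Wrap the block's scalable activation (as reported by
        get_act_for_scaling) in a ScaledActivation with unit scales.
        """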
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)