# Copyright (c) OpenMMLab. All rights reserved.
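"""AWQ (activation-aware weight quantization) for HuggingFace causal LMs.

Smooths each decoder layer using pre-collected activation statistics, then
quantizes the linear weights to a low-bit, per-group representation and
saves the result with `save_pretrained`.
"""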

from pathlib import Path

import fire
import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
                        load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                            quant_weights, smooth_layers)
from lmdeploy.lite.utils import collect_target_modules

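# Map an HF model class name to the class names of its decoder layer and
# RMSNorm implementation, so the relevant modules can be collected by type.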
LAYER_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMDecoderLayer',
    'QWenLMHeadModel': 'QWenBlock',
    'BaiChuanForCausalLM': 'DecoderLayer',
    'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMRMSNorm',
    'QWenLMHeadModel': 'RMSNorm',
    'BaiChuanForCausalLM': 'RMSNorm',
    'LlamaForCausalLM': 'LlamaRMSNorm',
}


def auto_awq(model: str,
             work_dir: str,
             w_bits: int = 4,
             w_sym: bool = False,
             w_group_size: int = 128,
             device: str = 'cuda'):
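    """Smooth and quantize a model's weights with AWQ, then save it.

    Args:
        model: Path or name of the HF model to load.
        work_dir: Directory that holds the pre-computed activation
            statistics (`inputs_stats.pth`) and receives the output.
        w_bits: Bit width of the quantized weights.
        w_sym: Whether to quantize weights symmetrically.
        w_group_size: Group size for per-group quantization.
        device: Device used for the smoothing / quantization computation.
    """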

    # Load tokenizer and configuration
    tokenizer = AutoTokenizer.from_pretrained(model,
                                              use_fast=False,
                                              trust_remote_code=True)
    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    # Keep the checkpoint path before `model` is rebound to the nn.Module.
    checkpoint = hf_config._name_or_path

    with init_empty_weights():
        # Instantiate the model on the meta device; no real weights are
        # allocated until load_checkpoint_in_model below.
        model = AutoModelForCausalLM.from_pretrained(model,
                                                     torch_dtype=torch.float16,
                                                     trust_remote_code=True)
        # The KV cache is not needed for quantization.
        model.config.use_cache = False

    # Look up the architecture-specific pairings: which linears share input
    # scales with which norm / linear modules. These drive the AWQ smoothing.
    layer_type = LAYER_TYPE_MAP[type(model).__name__]
    fc2fcs = FC_FCS_MAP[layer_type]
    norm2fcs = NORM_FCS_MAP[layer_type]

    decoder_layers = collect_target_modules(model, layer_type)

    # Infer a device map: keep the large decoder layers and lm_head on CPU,
    # and the remaining modules (embeddings etc.) on GPU 0. Smoothing and
    # quantization take a `device` argument and can move layers over as needed.
    device_map = infer_auto_device_map(model,
                                       no_split_module_classes=[layer_type])
    for name in device_map.keys():
        if name in decoder_layers or 'lm_head' in name:
            device_map[name] = 'cpu'
        else:
            device_map[name] = 0
    load_checkpoint_in_model(model, checkpoint, device_map)

    work_dir = Path(work_dir)

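    # Activation statistics ('absmean' of the layer inputs) gathered by a
    # prior calibration pass and saved to `work_dir`.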
    act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmean']
    # Re-collect the decoder layers and gather every nn.Linear inside them.
    layers = collect_target_modules(model, layer_type)
    fcs = {}
    for l_name, layer in layers.items():
        name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
        fcs.update(name2fc)

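    # AWQ proper: fold the activation scales into the preceding norm / linear
    # modules, then quantize every linear weight per group.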
    smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)
    quant_weights(model, fcs, w_bits, w_sym, w_group_size, device)

    model.save_pretrained(work_dir)
    tokenizer.save_pretrained(work_dir)


if __name__ == '__main__':
    fire.Fire(auto_awq)
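# Example invocation (paths are illustrative; `inputs_stats.pth` must already
# exist in the work dir, e.g. from a prior calibration run):
#   python auto_awq.py ./llama-7b ./work_dir --w_bits 4 --w_group_size 128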