# load_quant.py
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from awq.quantize.qmodule import WQLinear
from tqdm import tqdm

def load_awq_model(model, checkpoint, w_bit, group_size, device):
    """Load an AWQ checkpoint into a model skeleton built with empty (meta) weights."""
    q_config = {"zero_point": True, "q_group_size": group_size}
    # Swap Linear layers for empty quantized modules so their shapes match the checkpoint.
    real_quantize_model_weight(model, w_bit, q_config, init_only=True)

    # Single-step progress bar, used only to display a loading message.
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for _ in pbar:
        # Untie shared weights so every tensor is read directly from the checkpoint.
        if hasattr(model.config, "tie_encoder_decoder"):
            model.config.tie_encoder_decoder = False
        if hasattr(model.config, "tie_word_embeddings"):
            model.config.tie_word_embeddings = False
        model = load_checkpoint_and_dispatch(
            model, checkpoint,
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        ).to(device)
    return model
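
# Example usage (a minimal sketch, not part of the original script): the model id and
# checkpoint path below are placeholders. The skeleton is built on the meta device so
# load_checkpoint_and_dispatch can materialize the real weights.
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5")
#   with init_empty_weights():
#       model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
#   model = load_awq_model(model, "vicuna-7b-w4-g128-awq.pt",
#                          w_bit=4, group_size=128, device="cuda:0")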


def make_quant_linear(module, names, w_bit, groupsize, device, name=''):
    """Recursively replace the nn.Linear layers listed in `names` with WQLinear modules."""
    if isinstance(module, WQLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            # Drop the full-precision layer and install a quantized one with the same shape.
            delattr(module, attr)
            setattr(module, attr, WQLinear(w_bit, groupsize, tmp.in_features,
                                           tmp.out_features, tmp.bias is not None, device))
    for name1, child in module.named_children():
        make_quant_linear(child, names, w_bit, groupsize, device,
                          name + '.' + name1 if name != '' else name1)

def find_layers(module, layers=[nn.Linear], name=''):
    """Return a dict mapping dotted module names to every submodule whose type is in `layers`."""
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(child, layers=layers,
                               name=name + '.' + name1 if name != '' else name1))
    return res


def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
    """Fast loading path: swap Linear layers for WQLinear, then load the state dict directly."""
    layers = find_layers(model)
    # Keep the output head in full precision.
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant_linear(model, layers, w_bit, group_size, device)
    del layers

    # Single-step progress bar, used only to display a loading message.
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for _ in pbar:
        if checkpoint.endswith('.safetensors'):
            from safetensors.torch import load_file as safe_load
            model.load_state_dict(safe_load(checkpoint))
        else:
            model.load_state_dict(torch.load(checkpoint))

    return model.to(device)
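
# Example usage (a minimal sketch; the model id, checkpoint path, and quantization settings
# are placeholders): this path calls load_state_dict, which copies tensors in place, so the
# model should be built with real (non-meta) parameters.
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5")
#   model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
#   model = load_awq_llama_fast(model, "vicuna-7b-w4-g128-awq.safetensors",
#                               w_bit=4, group_size=128, device="cuda:0")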