hf_model.py 5.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from fastllm_pytools import llm;
import torch;
import ctypes;
import numpy as np;

# Numeric target-storage codes for converted linear weights, keyed by the
# `dtype` string accepted by create(). create() hands these values directly to
# the fastllm C library (presumably mirroring its DataType enum - keep in sync).
# Key order matters: create() lists the keys in its error message.
fastllm_data_type_dict = dict(
    int4=8,
    int8=3,
    float16=7,
)
# Numeric layer-kind codes. create() looks one up for each state-dict entry
# based on the module that owns it and passes it to the fastllm C library.
fastllm_weight_type_dict = dict(
    linear=1,
    embedding=2,
    QuantizedLinear=111,
)

def _tensor_to_numpy(tensor, np_dtype=None):
    """Return CPU ``tensor`` as a C-contiguous NumPy array, optionally cast.

    The fastllm C calls receive only a raw data pointer plus the shape (no
    strides), so the array is forced to C-contiguous layout; the C side
    presumably reads elements in row-major order. NumPy has no bfloat16, so
    bfloat16 tensors are first widened to float32 (a lossless conversion).
    ``np_dtype=None`` keeps the tensor's own dtype.
    """
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return np.ascontiguousarray(tensor.numpy(), dtype=np_dtype)


def _build_model_info(model, tokenizer, pre_prompt, user_role, bot_role, history_sep):
    """Return a COPY of ``model.config``'s attributes plus fastllm prompt settings.

    Copying keeps the caller's HuggingFace config object unmodified.
    Falsy prompt arguments (None, "") leave the config's own values in place.
    """
    model_info = dict(model.config.__dict__)
    overrides = {
        "pre_prompt": pre_prompt,
        "user_role": user_role,
        "bot_role": bot_role,
        "history_sep": history_sep,
    }
    for name, value in overrides.items():
        if value:
            model_info[name] = value

    model_type = model_info["model_type"]
    if (model_type == "baichuan" and hasattr(model, "model")
            and hasattr(model.model, "get_alibi_mask")):
        # Baichuan 2 (recognised by the inner model exposing get_alibi_mask):
        # roles are expressed as fixed token ids from the generation config,
        # overriding any user-supplied prompt settings.
        gen_config = model.generation_config
        model_info["use_alibi"] = "1"
        model_info["pre_prompt"] = ""
        model_info["user_role"] = (
            ("<FLM_FIX_TOKEN_" + str(gen_config.user_token_id) + "> ")
            if hasattr(gen_config, "user_token_id") else "")
        model_info["bot_role"] = (
            ("<FLM_FIX_TOKEN_" + str(gen_config.assistant_token_id) + ">")
            if hasattr(gen_config, "assistant_token_id") else "")
        model_info["history_sep"] = ""
    if model_type == "qwen":
        # Qwen needs its chat special-token ids; this requires a tokenizer.
        model_info["im_end_id"] = tokenizer.im_end_id
        model_info["im_start_id"] = tokenizer.im_start_id
    return model_info


def _classify_weights(model):
    """Label each module's ``<name>.weight`` state-dict key by layer kind.

    Returns ``(weight_types, weight_bits)``:
    - ``weight_types`` maps state-dict keys to a key of ``fastllm_weight_type_dict``.
    - ``weight_bits`` maps quantized-linear keys to the module's ``weight_bit_width``.
    """
    weight_types = {}
    weight_bits = {}
    for name, module in model.named_modules():
        weight_key = name + ".weight"
        # The quantized check comes first and is exclusive. A QuantizedLinear
        # class may subclass torch.nn.Linear, and its weight must not be
        # re-labelled as plain "linear", since it is exported together with a
        # separate "_scale" tensor.
        if "QuantizedLinear" in str(type(module)):
            weight_types[weight_key] = "QuantizedLinear"
            weight_bits[weight_key] = module.weight_bit_width
        elif isinstance(module, torch.nn.Linear):
            weight_types[weight_key] = "linear"
        elif isinstance(module, torch.nn.Embedding):
            # State-dict keys carry the ".weight" suffix, so the label must
            # too, or it would never be matched.
            weight_types[weight_key] = "embedding"
    return weight_types, weight_bits


def _add_vocab(handle, tokenizer, model_type):
    """Register every token of ``tokenizer`` with the fastllm model ``handle``."""
    # Some tokenizers wrap the real one in a ``.tokenizer`` attribute; for
    # Qwen the outer object itself is used.
    if hasattr(tokenizer, "tokenizer") and model_type != "qwen":
        tokenizer = tokenizer.tokenizer
    add_word = llm.fastllm_lib.add_tokenizer_word_llm_model

    if hasattr(tokenizer, "sp_model"):
        # SentencePiece: register ids 0..piece_size()-1 with their own scores.
        sp_model = tokenizer.sp_model
        for i in range(sp_model.piece_size()):
            add_word(handle, sp_model.id_to_piece(i).encode(), i,
                     ctypes.c_float(sp_model.get_score(i)))
        return

    vocab = tokenizer.get_vocab()
    for word, token_id in vocab.items():
        if model_type == "moss":
            # Map each character through tokenizer.byte_decoder; characters it
            # does not contain are passed as their code point.
            byte_decoder = tokenizer.byte_decoder
            piece = [byte_decoder[c] if c in byte_decoder else ord(c) for c in word]
        elif model_type == "qwen":
            # Vocabulary keys are passed through unchanged (no .encode()).
            piece = word
        else:
            piece = word.encode()
        add_word(handle, piece, token_id, ctypes.c_float(1.0))


def _add_weights(handle, state_dict, weight_types, weight_bits, dtype):
    """Send every tensor of ``state_dict`` to the fastllm model ``handle``.

    - Linear weights are stored as the data type named by ``dtype``.
    - Quantized linear weights are sent as-is with their ``<key>_scale`` tensor.
    - Everything else is sent as float32 (data type code 0).
    """
    linear_type = fastllm_weight_type_dict["linear"]
    quantized_type = fastllm_weight_type_dict["QuantizedLinear"]
    total = len(state_dict)
    for index, (key, tensor) in enumerate(state_dict.items(), start=1):
        cur_weight_type = fastllm_weight_type_dict.get(weight_types.get(key), 0)
        shape = list(tensor.shape)
        c_shape = (ctypes.c_int * len(shape))(*shape)

        if cur_weight_type == quantized_type:
            # Keep the arrays bound to locals so their buffers outlive the call.
            scale = _tensor_to_numpy(state_dict[key + "_scale"], np.float32)
            data = _tensor_to_numpy(tensor)
            llm.fastllm_lib.add_qlinear_weight_llm_model(
                handle, key.encode(), len(shape), c_shape, weight_bits[key],
                scale.ctypes.data_as(ctypes.c_void_p),
                data.ctypes.data_as(ctypes.c_void_p))
        else:
            to_data_type = 0
            ori_data_type = 0
            ori_np_data_type = np.float32
            if cur_weight_type == linear_type:
                to_data_type = fastllm_data_type_dict[dtype]
                if to_data_type == fastllm_data_type_dict["float16"]:
                    # Ship float16 data directly rather than float32.
                    ori_data_type = to_data_type
                    ori_np_data_type = np.float16
            data = _tensor_to_numpy(tensor, ori_np_data_type)
            llm.fastllm_lib.add_weight_llm_model(
                handle, key.encode(), len(shape), c_shape,
                to_data_type, cur_weight_type, ori_data_type,
                data.ctypes.data_as(ctypes.c_void_p))
        print("convert (", index, "/", total, end = " )\r")


def create(model,
           tokenizer = None,
           pre_prompt = None,
           user_role = None,
           bot_role = None,
           history_sep = None,
           dtype = "float16"):
    """Convert a HuggingFace-style torch model into a fastllm model.

    Args:
        model: torch module exposing ``config`` (with ``model_type``),
            ``named_modules()``, ``cpu()`` and ``state_dict()``.
        tokenizer: optional tokenizer whose vocabulary is registered with the
            fastllm model. Required when ``model_type`` is ``"qwen"``, because
            its ``im_start_id``/``im_end_id`` are read.
        pre_prompt, user_role, bot_role, history_sep: optional prompt-template
            strings stored in the model info. Falsy values are ignored, and
            all four are overridden for Baichuan 2 models.
        dtype: storage type for ``torch.nn.Linear`` weights; one of the keys of
            ``fastllm_data_type_dict``.

    Returns:
        The result of ``llm.model("", id=handle)`` for the converted model.

    Raises:
        ValueError: if ``dtype`` is not a key of ``fastllm_data_type_dict``.

    Side effects:
        ``model`` is moved to the CPU in place (``nn.Module.cpu``).
        ``model.config`` is copied, not modified.
    """
    if dtype not in fastllm_data_type_dict:
        # Raise rather than exit(0): a library call must not terminate the
        # process, least of all with a "success" exit status.
        raise ValueError(
            f"dtype should be one of {list(fastllm_data_type_dict.keys())}, got {dtype!r}")

    model_info = _build_model_info(model, tokenizer, pre_prompt, user_role,
                                   bot_role, history_sep)
    model_type = model_info["model_type"]
    weight_types, weight_bits = _classify_weights(model)

    # Tensors must be on the CPU to be viewed as NumPy arrays.
    model = model.cpu()
    state_dict = model.state_dict()

    handle = llm.fastllm_lib.create_empty_llm_model(model_type.encode())
    for name, value in model_info.items():
        llm.fastllm_lib.add_dict_llm_model(handle, str(name).encode(), str(value).encode())

    if tokenizer:
        _add_vocab(handle, tokenizer, model_type)

    _add_weights(handle, state_dict, weight_types, weight_bits, dtype)
    print("")

    llm.fastllm_lib.init_params_llm_model(handle)
    llm.fastllm_lib.warmup_llm_model(handle)
    return llm.model("", id = handle)