jiuge_awq_weight.cpp 6.24 KB
Newer Older
blkmjsian's avatar
blkmjsian committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include "jiuge_awq.hpp"

#include <cmath>
#include <cstdint>
#include <memory>
#include <vector>

/// Build the RoPE sine table as an FP16 tensor of shape {dctx, dh/2}.
/// table[i][j] = sin(i / theta^(j / (dh/2)))
/// @param dctx  maximum context length (rows)
/// @param dh    per-head dimension; only dh/2 rotary frequencies are stored
/// @param theta RoPE base frequency
inline std::shared_ptr<Tensor> getSinTable(size_t dctx, size_t dh, float theta) {
    auto half_dh = dh / 2;
    // RAII buffer instead of raw malloc/free: no leak on early exit/throw,
    // and the element type makes the FP16 layout explicit.
    std::vector<uint16_t> table(dctx * half_dh);

    for (size_t i = 0; i < dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float angle = static_cast<float>(i)
                        / std::pow(theta, static_cast<float>(j) / half_dh);
            table[i * half_dh + j] = f32_to_f16(std::sin(angle));
        }
    }
    auto shape = std::vector<size_t>({dctx, half_dh});
    // Tensor::weight copies the host buffer, so the local vector may die here.
    return Tensor::weight(table.data(), INFINI_DTYPE_F16, shape);
}

/// Build the RoPE cosine table as an FP16 tensor of shape {dctx, dh/2}.
/// table[i][j] = cos(i / theta^(j / (dh/2)))
/// @param dctx  maximum context length (rows)
/// @param dh    per-head dimension; only dh/2 rotary frequencies are stored
/// @param theta RoPE base frequency
inline std::shared_ptr<Tensor> getCosTable(size_t dctx, size_t dh, float theta) {
    auto half_dh = dh / 2;
    // RAII buffer instead of raw malloc/free: no leak on early exit/throw,
    // and the element type makes the FP16 layout explicit.
    std::vector<uint16_t> table(dctx * half_dh);

    for (size_t i = 0; i < dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float angle = static_cast<float>(i)
                        / std::pow(theta, static_cast<float>(j) / half_dh);
            table[i * half_dh + j] = f32_to_f16(std::cos(angle));
        }
    }
    auto shape = std::vector<size_t>({dctx, half_dh});
    // Tensor::weight copies the host buffer, so the local vector may die here.
    return Tensor::weight(table.data(), INFINI_DTYPE_F16, shape);
}

/// Declare and register every weight tensor of a Jiuge AWQ model with the
/// infinicore weight loader, once per device in `dev_ids`.
///
/// Tensors are created with a null data pointer: only shape/dtype metadata is
/// declared here; actual data is filled in later by the Loader when weights
/// are streamed in by name.
///
/// Head counts (nh, nkvh) and the FFN width (di) are divided by the device
/// count, so each device registers its tensor-parallel shard; COLUMN/ROW
/// distribution types tell the loader how to slice the full checkpoint tensor.
/// NOTE(review): assumes ndev > 0 and that nh, nkvh, di are divisible by
/// ndev — TODO confirm the caller validates this.
JiugeAWQWeights::JiugeAWQWeights(
    const JiugeAWQMeta *meta,
    infiniDevice_t device,
    const std::vector<int> &dev_ids) : infinicore::weights::Loader(device, dev_ids) {
    auto ndev = dev_ids.size();
    _device_weights.resize(ndev);
    infiniDtype_t dt_logits = meta->dt_logits;
    infiniDtype_t dt_norm_w = meta->dt_norm_w;
    size_t nlayer = meta->nlayer;
    size_t d = meta->d;
    size_t nh = meta->nh / ndev;     // attention heads per device
    size_t nkvh = meta->nkvh / ndev; // KV heads per device
    size_t dh = meta->dh;
    size_t di = meta->di / ndev;     // FFN intermediate size per device
    size_t dctx = meta->dctx;
    size_t dvoc = meta->dvoc;
    size_t nbit = meta->nbit;
    size_t quant_group_size = meta->quant_group_size;

    for (size_t i = 0; i < ndev; i++) {
        RUN_INFINI(infinirtSetDevice(device, dev_ids[i]));

        auto weight = std::make_shared<JiugeAWQDeviceWeight>();
        _device_weights[i] = weight;

        // Token embedding table: {dvoc, d}.
        auto w_in_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d});
        this->register_weight("model.embed_tokens.weight", w_in_embd, i);
        weight->w_in_embd = w_in_embd;

        // Final (pre-lm_head) RMS/LayerNorm weight.
        auto w_out_norm = Tensor::weight(nullptr, dt_norm_w, {d});
        this->register_weight("model.norm.weight", w_out_norm, i);
        weight->w_out_norm = w_out_norm;

        // Output projection, stored transposed ({d, dvoc} view via permute).
        auto w_out_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d})->permute({1, 0});
        this->register_weight("lm_head.weight", w_out_embd, i);
        weight->w_out_embd = w_out_embd;

        weight->sin_table = getSinTable(dctx, dh, meta->theta);
        weight->cos_table = getCosTable(dctx, dh, meta->theta);

        for (size_t layer = 0; layer < nlayer; layer++) {

// Declare one full-precision per-layer tensor, register it with the loader
// under W_NAME, and append it to the per-device weight list W_VAR.
#define REGISTER_LAYER_WEIGHT(W_NAME, W_VAR, W_SHAPE, W_DTYPE, W_DIST_TYPE)                      \
    auto W_VAR = Tensor::weight(nullptr, W_DTYPE, W_SHAPE);                                      \
    this->register_weight(W_NAME, W_VAR, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
    weight->W_VAR.push_back(W_VAR);

            REGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".input_layernorm.weight", w_attn_norm, {d}, dt_norm_w, FULL);

// Declare one AWQ-quantized linear layer ({W_IN, W_OUT} logical shape) as its
// three checkpoint tensors and append the bundle to weight->W_VAR:
//   .qweight: int32-packed 4-bit weights, {W_IN, W_OUT*nbit/32}
//   .scales : FP16 group scales,          {W_IN/group, W_OUT}
//   .qzeros : int32-packed zero points,   {W_IN/group, W_OUT*nbit/32}
#define REGISTER_LAYER_QUANT_WEIGHT(W_NAME, W_VAR, W_IN, W_OUT, W_DIST_TYPE)                                     \
    auto W_VAR = std::make_shared<QuantInt4Weight>();                                                            \
    W_VAR->w = Tensor::weight(nullptr, INFINI_DTYPE_I32, {W_IN, (W_OUT)*nbit / 32});                             \
    this->register_weight(W_NAME + ".qweight", W_VAR->w, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
    W_VAR->s = Tensor::weight(nullptr, INFINI_DTYPE_F16, {(W_IN) / quant_group_size, (W_OUT)});                  \
    this->register_weight(W_NAME + ".scales", W_VAR->s, i, infinicore::weights::DistributionType::W_DIST_TYPE);  \
    W_VAR->z = Tensor::weight(nullptr, INFINI_DTYPE_I32, {(W_IN) / quant_group_size, (W_OUT)*nbit / 32});        \
    this->register_weight(W_NAME + ".qzeros", W_VAR->z, i, infinicore::weights::DistributionType::W_DIST_TYPE);  \
    weight->W_VAR.push_back(W_VAR);

            // Attention: Q/K/V are column-sharded across devices (head split);
            // the output projection is row-sharded (reduced after matmul).
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.q_proj", w_attn_q, d, nh * dh, COLUMN);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.k_proj", w_attn_k, d, nkvh * dh, COLUMN);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.v_proj", w_attn_v, d, nkvh * dh, COLUMN);
            REGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.q_proj.bias", b_attn_q, {nh * dh}, INFINI_DTYPE_F16, COLUMN);
            REGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.k_proj.bias", b_attn_k, {nkvh * dh}, INFINI_DTYPE_F16, COLUMN);
            REGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.v_proj.bias", b_attn_v, {nkvh * dh}, INFINI_DTYPE_F16, COLUMN);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.o_proj", w_attn_out, nh * dh, d, ROW);

            // FFN: gate/up are column-sharded, down is row-sharded.
            REGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".post_attention_layernorm.weight", w_ffn_norm, {d}, dt_norm_w, FULL);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.gate_proj", w_ffn_gate, d, di, COLUMN);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.up_proj", w_ffn_up, d, di, COLUMN);
            REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.down_proj", w_ffn_down, di, d, ROW);
        }
    }

#undef REGISTER_LAYER_WEIGHT
#undef REGISTER_LAYER_QUANT_WEIGHT
}

121
__INFINI_C struct ModelWeights *
blkmjsian's avatar
blkmjsian committed
122
123
124
125
126
127
128
createJiugeAWQWeights(const JiugeAWQMeta *meta,
                      infiniDevice_t device,
                      int ndev,
                      const int *dev_ids) {
    JiugeAWQWeights *weights = new JiugeAWQWeights(meta, device, std::vector<int>(dev_ids, dev_ids + ndev));
    return (struct ModelWeights *)weights;
}