jiuge_weight.hpp 6.92 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
6
#ifndef JIUGE_WEIGHT_HPP
#define JIUGE_WEIGHT_HPP

#include "jiuge_impl.hpp"

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
PanZezhong's avatar
PanZezhong committed
7
inline std::shared_ptr<Tensor> getInEmbd(
PanZezhong's avatar
init  
PanZezhong committed
8
9
10
11
12
13
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    auto shape = std::vector<size_t>({meta->dvoc, meta->d});
    return Tensor::weight((char *)w->input_embd, meta->dt_logits, shape);
}

PanZezhong's avatar
PanZezhong committed
14
inline std::shared_ptr<Tensor> getOutNorm(
PanZezhong's avatar
init  
PanZezhong committed
15
16
17
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
18
    return Tensor::weight((char *)w->output_norm, w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
19
20
}

PanZezhong's avatar
PanZezhong committed
21
inline std::shared_ptr<Tensor> getOutEmbd(
PanZezhong's avatar
init  
PanZezhong committed
22
23
    JiugeMeta const *meta,
    JiugeWeights const *w) {
PanZezhong's avatar
PanZezhong committed
24
25
26
27
28
29
30
31
    if (w->transpose_linear_weights != 0) {
        auto shape = std::vector<size_t>({meta->dvoc, meta->d});
        return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape)
            ->permute({1, 0});
    } else {
        auto shape = std::vector<size_t>({meta->d, meta->dvoc});
        return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape);
    }
PanZezhong's avatar
init  
PanZezhong committed
32
33
}

PanZezhong's avatar
PanZezhong committed
34
inline std::shared_ptr<Tensor> getAttnNorm(
PanZezhong's avatar
init  
PanZezhong committed
35
36
37
38
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
39
    return Tensor::weight((char *)(w->attn_norm[layer]), w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
40
41
}

PanZezhong's avatar
PanZezhong committed
42
inline std::shared_ptr<Tensor> getAttnQKV(
PanZezhong's avatar
init  
PanZezhong committed
43
44
45
46
47
48
49
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto nkvh = meta->nkvh;
    auto nh = meta->nh;
    auto dh = meta->dh;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
50
    size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(w->dt_mat);
PanZezhong's avatar
PanZezhong committed
51
52
53
54
55
56
57
58
    if (w->transpose_linear_weights != 0) {
        auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh, d});
        return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape)
            ->permute({1, 0});
    } else {
        auto shape = std::vector<size_t>({d, (nh + 2 * nkvh) / ndev * dh});
        return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape);
    }
PanZezhong's avatar
init  
PanZezhong committed
59
60
}

PanZezhong's avatar
PanZezhong committed
61
inline std::shared_ptr<Tensor> getAttnQKVBias(
PanZezhong's avatar
init  
PanZezhong committed
62
63
64
65
66
67
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto nkvh = meta->nkvh;
    auto nh = meta->nh;
    auto dh = meta->dh;
PanZezhong's avatar
PanZezhong committed
68
    size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(w->dt_mat);
Pan Zezhong's avatar
Pan Zezhong committed
69
    auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh});
PanZezhong's avatar
PanZezhong committed
70
    return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
PanZezhong's avatar
init  
PanZezhong committed
71
72
}

PanZezhong's avatar
PanZezhong committed
73
74
75
/// Returns device `idev`'s slice (one of `ndev`) of layer `layer`'s attention
/// output projection, with logical shape [rows, d] where
/// rows = (nh / ndev) * dh, in dtype `w->dt_mat`.
///
/// When `w->transpose_linear_weights` is non-zero each per-device chunk is
/// stored as [d, rows] and its axes are swapped with permute({1, 0}).
/// NOTE(review): assumes nh is divisible by ndev; not checked.
inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
                                        JiugeWeights const *w, size_t layer,
                                        size_t idev, size_t ndev) {
    size_t const d = meta->d;
    size_t const local_rows = meta->nh / ndev * meta->dh;
    char *slice = (char *)(w->attn_o[layer])
                + idev * d * local_rows * dsize(w->dt_mat);

    if (w->transpose_linear_weights == 0) {
        std::vector<size_t> logical_shape{local_rows, d};
        return Tensor::weight(slice, w->dt_mat, logical_shape);
    }

    std::vector<size_t> stored_shape{d, local_rows};
    return Tensor::weight(slice, w->dt_mat, stored_shape)->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
90
inline std::shared_ptr<Tensor> getFFNNorm(
PanZezhong's avatar
init  
PanZezhong committed
91
92
93
94
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
95
    return Tensor::weight((char *)(w->ffn_norm[layer]), w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
96
97
}

PanZezhong's avatar
PanZezhong committed
98
inline std::shared_ptr<Tensor> getFFNGateUp(
PanZezhong's avatar
init  
PanZezhong committed
99
100
101
102
103
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto di = meta->di;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
104
    size_t offset = idev * (2 * di / ndev) * d * dsize(w->dt_mat);
PanZezhong's avatar
PanZezhong committed
105
106
107
108
109
110
111
112
113
114
    if (w->transpose_linear_weights != 0) {
        auto shape = std::vector<size_t>({2 * di / ndev, d});
        return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
                              w->dt_mat, shape)
            ->permute({1, 0});
    } else {
        auto shape = std::vector<size_t>({d, 2 * di / ndev});
        return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
                              w->dt_mat, shape);
    }
PanZezhong's avatar
init  
PanZezhong committed
115
116
}

PanZezhong's avatar
PanZezhong committed
117
inline std::shared_ptr<Tensor> getFFNDown(
PanZezhong's avatar
init  
PanZezhong committed
118
119
120
121
122
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto di = meta->di;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
123
    size_t offset = idev * d * (di / ndev) * dsize(w->dt_mat);
PanZezhong's avatar
PanZezhong committed
124
125
126
127
128
129
130
131
    if (w->transpose_linear_weights != 0) {
        auto shape = std::vector<size_t>({d, di / ndev});
        return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape)
            ->permute({1, 0});
    } else {
        auto shape = std::vector<size_t>({di / ndev, d});
        return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape);
    }
PanZezhong's avatar
init  
PanZezhong committed
132
133
}

PanZezhong's avatar
PanZezhong committed
134
inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
PanZezhong's avatar
init  
PanZezhong committed
135
    auto half_dh = meta->dh / 2;
PanZezhong's avatar
PanZezhong committed
136
137
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);
PanZezhong's avatar
PanZezhong committed
138

PanZezhong's avatar
init  
PanZezhong committed
139
140
141
142
    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _sin = std::sin(
                static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
PanZezhong's avatar
PanZezhong committed
143
144
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
PanZezhong's avatar
PanZezhong committed
145
146
            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
PanZezhong's avatar
PanZezhong committed
147
148
149
150
151
152
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _sin;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
PanZezhong's avatar
init  
PanZezhong committed
153
154
        }
    }
PanZezhong's avatar
PanZezhong committed
155
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
PanZezhong's avatar
init  
PanZezhong committed
156
157
158
159
160
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

PanZezhong's avatar
PanZezhong committed
161
inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
PanZezhong's avatar
init  
PanZezhong committed
162
    auto half_dh = meta->dh / 2;
PanZezhong's avatar
PanZezhong committed
163
164
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);
PanZezhong's avatar
PanZezhong committed
165

PanZezhong's avatar
init  
PanZezhong committed
166
167
168
169
    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _cos = std::cos(
                static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
PanZezhong's avatar
PanZezhong committed
170
171
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
PanZezhong's avatar
PanZezhong committed
172
173
            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
PanZezhong's avatar
PanZezhong committed
174
175
176
177
178
179
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _cos;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
PanZezhong's avatar
init  
PanZezhong committed
180
181
        }
    }
PanZezhong's avatar
PanZezhong committed
182
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
PanZezhong's avatar
init  
PanZezhong committed
183
184
185
186
187
188
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

#endif