// jiuge_weight.hpp — accessors that wrap raw Jiuge model weights as Tensor views.
#ifndef JIUGE_WEIGHT_HPP
#define JIUGE_WEIGHT_HPP

#include "jiuge_impl.hpp"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
PanZezhong's avatar
PanZezhong committed
7
inline std::shared_ptr<Tensor> getInEmbd(
PanZezhong's avatar
init  
PanZezhong committed
8
9
10
11
12
13
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    auto shape = std::vector<size_t>({meta->dvoc, meta->d});
    return Tensor::weight((char *)w->input_embd, meta->dt_logits, shape);
}

PanZezhong's avatar
PanZezhong committed
14
inline std::shared_ptr<Tensor> getOutNorm(
PanZezhong's avatar
init  
PanZezhong committed
15
16
17
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
18
    return Tensor::weight((char *)w->output_norm, w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
19
20
}

PanZezhong's avatar
PanZezhong committed
21
inline std::shared_ptr<Tensor> getOutEmbd(
PanZezhong's avatar
init  
PanZezhong committed
22
23
24
25
26
27
28
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    auto shape = std::vector<size_t>({meta->dvoc, meta->d});
    return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape)
        ->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
29
inline std::shared_ptr<Tensor> getAttnNorm(
PanZezhong's avatar
init  
PanZezhong committed
30
31
32
33
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
34
    return Tensor::weight((char *)(w->attn_norm[layer]), w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
35
36
}

PanZezhong's avatar
PanZezhong committed
37
inline std::shared_ptr<Tensor> getAttnQKV(
PanZezhong's avatar
init  
PanZezhong committed
38
39
40
41
42
43
44
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto nkvh = meta->nkvh;
    auto nh = meta->nh;
    auto dh = meta->dh;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
45
    size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(w->dt_mat);
PanZezhong's avatar
init  
PanZezhong committed
46
    auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh, d});
PanZezhong's avatar
PanZezhong committed
47
    return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape)
PanZezhong's avatar
init  
PanZezhong committed
48
49
50
        ->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
51
inline std::shared_ptr<Tensor> getAttnQKVBias(
PanZezhong's avatar
init  
PanZezhong committed
52
53
54
55
56
57
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto nkvh = meta->nkvh;
    auto nh = meta->nh;
    auto dh = meta->dh;
PanZezhong's avatar
PanZezhong committed
58
    size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(w->dt_mat);
Pan Zezhong's avatar
Pan Zezhong committed
59
    auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh});
PanZezhong's avatar
PanZezhong committed
60
    return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
PanZezhong's avatar
init  
PanZezhong committed
61
62
}

PanZezhong's avatar
PanZezhong committed
63
64
65
inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
                                        JiugeWeights const *w, size_t layer,
                                        size_t idev, size_t ndev) {
PanZezhong's avatar
init  
PanZezhong committed
66
67
68
    auto nh = meta->nh;
    auto dh = meta->dh;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
69
    size_t offset = idev * d * (nh / ndev * dh) * dsize(w->dt_mat);
PanZezhong's avatar
init  
PanZezhong committed
70
    auto shape = std::vector<size_t>({d, nh / ndev * dh});
PanZezhong's avatar
PanZezhong committed
71
    return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape)
PanZezhong's avatar
init  
PanZezhong committed
72
73
74
        ->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
75
inline std::shared_ptr<Tensor> getFFNNorm(
PanZezhong's avatar
init  
PanZezhong committed
76
77
78
79
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    auto shape = std::vector<size_t>({meta->d});
PanZezhong's avatar
PanZezhong committed
80
    return Tensor::weight((char *)(w->ffn_norm[layer]), w->dt_norm, shape);
PanZezhong's avatar
init  
PanZezhong committed
81
82
}

PanZezhong's avatar
PanZezhong committed
83
inline std::shared_ptr<Tensor> getFFNGateUp(
PanZezhong's avatar
init  
PanZezhong committed
84
85
86
87
88
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto di = meta->di;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
89
    size_t offset = idev * (2 * di / ndev) * d * dsize(w->dt_mat);
PanZezhong's avatar
init  
PanZezhong committed
90
91
    auto shape = std::vector<size_t>({2 * di / ndev, d});
    return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
PanZezhong's avatar
PanZezhong committed
92
                          w->dt_mat, shape)
PanZezhong's avatar
init  
PanZezhong committed
93
94
95
        ->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
96
inline std::shared_ptr<Tensor> getFFNDown(
PanZezhong's avatar
init  
PanZezhong committed
97
98
99
100
101
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    auto di = meta->di;
    auto d = meta->d;
PanZezhong's avatar
PanZezhong committed
102
    size_t offset = idev * d * (di / ndev) * dsize(w->dt_mat);
PanZezhong's avatar
init  
PanZezhong committed
103
    auto shape = std::vector<size_t>({d, di / ndev});
PanZezhong's avatar
PanZezhong committed
104
    return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape)
PanZezhong's avatar
init  
PanZezhong committed
105
106
107
        ->permute({1, 0});
}

PanZezhong's avatar
PanZezhong committed
108
inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
PanZezhong's avatar
init  
PanZezhong committed
109
    auto half_dh = meta->dh / 2;
PanZezhong's avatar
PanZezhong committed
110
111
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);
PanZezhong's avatar
PanZezhong committed
112

PanZezhong's avatar
init  
PanZezhong committed
113
114
115
116
    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _sin = std::sin(
                static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
PanZezhong's avatar
PanZezhong committed
117
118
119
120
121
122
123
124
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _sin;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
PanZezhong's avatar
init  
PanZezhong committed
125
126
        }
    }
PanZezhong's avatar
PanZezhong committed
127
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
PanZezhong's avatar
init  
PanZezhong committed
128
129
130
131
132
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

PanZezhong's avatar
PanZezhong committed
133
inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
PanZezhong's avatar
init  
PanZezhong committed
134
    auto half_dh = meta->dh / 2;
PanZezhong's avatar
PanZezhong committed
135
136
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);
PanZezhong's avatar
PanZezhong committed
137

PanZezhong's avatar
init  
PanZezhong committed
138
139
140
141
    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _cos = std::cos(
                static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
PanZezhong's avatar
PanZezhong committed
142
143
144
145
146
147
148
149
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _cos;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
PanZezhong's avatar
init  
PanZezhong committed
150
151
        }
    }
PanZezhong's avatar
PanZezhong committed
152
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
PanZezhong's avatar
init  
PanZezhong committed
153
154
155
156
157
158
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

#endif