// jiuge_weight.hpp — header-only accessors that wrap raw model weight
// pointers (JiugeWeights) as Tensor views, sliced per device for tensor
// parallelism.
#ifndef JIUGE_WEIGHT_HPP
#define JIUGE_WEIGHT_HPP

#include "jiuge_impl.hpp"

#include <cmath>
#include <cstddef>
#include <memory>
#include <vector>
inline std::shared_ptr<Tensor> get_in_embd(
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    // Input token embedding table, laid out as [dvoc, d] in the logits dtype.
    std::vector<size_t> dims{meta->dvoc, meta->d};
    return Tensor::weight((char *)w->input_embd, meta->dt_logits, dims);
}

inline std::shared_ptr<Tensor> get_out_norm(
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    // Final (pre-logits) normalization weight vector of length d.
    std::vector<size_t> dims{meta->d};
    return Tensor::weight((char *)w->output_norm, meta->dt_norm, dims);
}

inline std::shared_ptr<Tensor> get_out_embd(
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    // Output embedding (lm head): stored as [dvoc, d], exposed transposed
    // as [d, dvoc] via a permuted view.
    std::vector<size_t> dims{meta->dvoc, meta->d};
    auto weight = Tensor::weight((char *)w->output_embd, meta->dt_logits, dims);
    return weight->permute({1, 0});
}

inline std::shared_ptr<Tensor> get_attn_norm(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    // Per-layer pre-attention normalization weight vector of length d.
    std::vector<size_t> dims{meta->d};
    return Tensor::weight((char *)(w->attn_norm[layer]), meta->dt_norm, dims);
}

inline std::shared_ptr<Tensor> get_attn_qkv(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    // Fused QKV weight slice for device idev of ndev.
    // Rows per device: (nh + 2*nkvh)/ndev heads, each dh wide; the stored
    // layout is [rows, d], returned transposed as a [d, rows] view.
    const size_t rows = (meta->nh + 2 * meta->nkvh) / ndev * meta->dh;
    const size_t byte_offset = idev * rows * meta->d * dsize(meta->dt_mat);
    std::vector<size_t> dims{rows, meta->d};
    auto weight = Tensor::weight((char *)(w->attn_qkv[layer]) + byte_offset,
                                 meta->dt_mat, dims);
    return weight->permute({1, 0});
}

inline std::shared_ptr<Tensor> get_attn_qkv_bias(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    // This device's slice of the fused QKV bias, shaped [1, slice_len]
    // so it broadcasts over the token dimension.
    const size_t slice_len = (meta->nh + 2 * meta->nkvh) / ndev * meta->dh;
    const size_t byte_offset = idev * slice_len * dsize(meta->dt_mat);
    std::vector<size_t> dims{1, slice_len};
    return Tensor::weight((char *)(w->attn_qkv_b[layer]) + byte_offset,
                          meta->dt_mat, dims);
}

inline std::shared_ptr<Tensor> get_attn_o(JiugeMeta const *meta,
                                          JiugeWeights const *w, size_t layer,
                                          size_t idev, size_t ndev) {
    // Attention output-projection slice: stored [d, nh/ndev * dh], returned
    // transposed via a permuted view.
    const size_t cols = meta->nh / ndev * meta->dh;
    const size_t byte_offset = idev * meta->d * cols * dsize(meta->dt_mat);
    std::vector<size_t> dims{meta->d, cols};
    auto weight = Tensor::weight((char *)(w->attn_o[layer]) + byte_offset,
                                 meta->dt_mat, dims);
    return weight->permute({1, 0});
}

inline std::shared_ptr<Tensor> get_ffn_norm(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer) {
    // Per-layer pre-FFN normalization weight vector of length d.
    std::vector<size_t> dims{meta->d};
    return Tensor::weight((char *)(w->ffn_norm[layer]), meta->dt_norm, dims);
}

inline std::shared_ptr<Tensor> get_ffn_gate_up(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    // Fused gate+up FFN weight slice: 2*di/ndev rows of width d per device,
    // returned transposed via a permuted view.
    const size_t rows = 2 * meta->di / ndev;
    const size_t byte_offset = idev * rows * meta->d * dsize(meta->dt_mat);
    std::vector<size_t> dims{rows, meta->d};
    auto weight = Tensor::weight((char *)(w->ffn_gate_up[layer]) + byte_offset,
                                 meta->dt_mat, dims);
    return weight->permute({1, 0});
}

inline std::shared_ptr<Tensor> get_ffn_down(
    JiugeMeta const *meta,
    JiugeWeights const *w,
    size_t layer, size_t idev, size_t ndev) {
    // Down-projection slice: stored [d, di/ndev] per device, returned
    // transposed via a permuted view.
    const size_t cols = meta->di / ndev;
    const size_t byte_offset = idev * meta->d * cols * dsize(meta->dt_mat);
    std::vector<size_t> dims{meta->d, cols};
    auto weight = Tensor::weight((char *)(w->ffn_down[layer]) + byte_offset,
                                 meta->dt_mat, dims);
    return weight->permute({1, 0});
}

inline std::shared_ptr<Tensor> get_sin_table(JiugeMeta const *meta) {
    // Build the RoPE sine table, shape [dctx, dh]: entries (i, 2j) and
    // (i, 2j+1) both hold sin(i / theta^(j / (dh/2))), the interleaved
    // rotary layout.
    auto half_dh = meta->dh / 2;
    // std::vector (RAII) instead of raw malloc/free: the original leaked the
    // buffer if Tensor::weight threw before std::free was reached.
    std::vector<float> table(meta->dctx * meta->dh);
    for (size_t j = 0; j < half_dh; j++) {
        // Hoisted loop-invariant: pow() depends only on j, so evaluate it
        // once per column pair instead of once per (i, j) element.
        auto freq = std::pow(meta->theta, static_cast<float>(j) / half_dh);
        for (size_t i = 0; i < meta->dctx; i++) {
            float _sin = std::sin(static_cast<float>(i) / freq);
            table[i * meta->dh + 2 * j] = _sin;
            table[i * meta->dh + 2 * j + 1] = _sin;
        }
    }
    auto shape = std::vector<size_t>({meta->dctx, meta->dh});
    // The original freed the buffer right after this call, so Tensor::weight
    // evidently copies the data; handing it a stack-owned vector is safe.
    return Tensor::weight(table.data(), meta->dt_logits, shape);
}

inline std::shared_ptr<Tensor> get_cos_table(JiugeMeta const *meta) {
    // Build the RoPE cosine table, shape [dctx, dh]: entries (i, 2j) and
    // (i, 2j+1) both hold cos(i / theta^(j / (dh/2))), the interleaved
    // rotary layout.
    auto half_dh = meta->dh / 2;
    // std::vector (RAII) instead of raw malloc/free: the original leaked the
    // buffer if Tensor::weight threw before std::free was reached.
    std::vector<float> table(meta->dctx * meta->dh);
    for (size_t j = 0; j < half_dh; j++) {
        // Hoisted loop-invariant: pow() depends only on j, so evaluate it
        // once per column pair instead of once per (i, j) element.
        auto freq = std::pow(meta->theta, static_cast<float>(j) / half_dh);
        for (size_t i = 0; i < meta->dctx; i++) {
            float _cos = std::cos(static_cast<float>(i) / freq);
            table[i * meta->dh + 2 * j] = _cos;
            table[i * meta->dh + 2 * j + 1] = _cos;
        }
    }
    auto shape = std::vector<size_t>({meta->dctx, meta->dh});
    // The original freed the buffer right after this call, so Tensor::weight
    // evidently copies the data; handing it a stack-owned vector is safe.
    return Tensor::weight(table.data(), meta->dt_logits, shape);
}

#endif