// jiuge.cpp
#include "jiuge_impl.hpp"
#include "jiuge_weight.hpp"

#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "infinicore_infer.h"

#include <algorithm>
#include <cmath>
#include <random>
#include <thread>
#include <utility>
#include <vector>

// Builds the per-device inference state and stores it in `*rsrc`.
//
// Steps, in order:
//   1. make `dev_id` of kind `device` the current device,
//   2. create an operator handle and a stream,
//   3. collect the per-layer weight tensors through the get* helpers,
//   4. create a MemoryPool constructed with 128 * 1024 * 1024
//      (presumably a size in bytes -- confirm against MemoryPool),
//   5. assemble the DeviceResource and synchronize the device.
//
// rsrc     destination; overwritten as a whole.
// meta     model hyper-parameters; `meta->nlayer` drives the layer loop.
// weights  host-side weight descriptor handed to the get* helpers.
// idev     forwarded to the sharded weight getters together with `ndev`
//          (presumably this device's rank -- confirm against jiuge_weight.hpp).
// ndev     forwarded to the sharded weight getters (presumably device count).
// dev_id   runtime device id passed to infinirtSetDevice.
// comm     communicator stored as-is; may be nullptr (inferDeviceBatch checks).
void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
                          const JiugeWeights *weights,
                          infiniDevice_t device, int idev,
                          int ndev, int dev_id,
                          infinicclComm_t comm) {
    RUN_INFINI(infinirtSetDevice(device, dev_id));
    // Check handle/stream creation like every other runtime call in this file,
    // instead of silently continuing with an uninitialized handle or stream.
    infiniopHandle_t handle;
    RUN_INFINI(infiniopCreateHandle(&handle));
    infinirtStream_t stream;
    RUN_INFINI(infinirtStreamCreate(&stream));

    const bool has_qkv_bias = weights->attn_qkv_b != nullptr;

    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    w_attn_norm.reserve(meta->nlayer);
    w_attn_qkv.reserve(meta->nlayer);
    if (has_qkv_bias) {
        b_attn_qkv.reserve(meta->nlayer);
    }
    w_attn_out.reserve(meta->nlayer);
    w_ffn_norm.reserve(meta->nlayer);
    w_ffn_gate_up.reserve(meta->nlayer);
    w_ffn_down.reserve(meta->nlayer);

    for (size_t layer = 0; layer < meta->nlayer; layer++) {
        w_attn_norm.push_back(
            getAttnNorm(meta, weights, layer));
        w_attn_qkv.push_back(
            getAttnQKV(meta, weights, layer, idev, ndev));
        // b_attn_qkv stays empty when the model has no QKV bias;
        // inferDeviceBatch uses its size to detect that.
        if (has_qkv_bias) {
            b_attn_qkv.push_back(
                getAttnQKVBias(meta, weights, layer, idev, ndev));
        }
        w_attn_out.push_back(
            getAttnO(meta, weights, layer, idev, ndev));
        w_ffn_norm.push_back(
            getFFNNorm(meta, weights, layer));
        w_ffn_gate_up.push_back(
            getFFNGateUp(meta, weights, layer, idev, ndev));
        w_ffn_down.push_back(
            getFFNDown(meta, weights, layer, idev, ndev));
    }

    auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024);

    // Initializers are evaluated in the order written, matching the original.
    *rsrc = DeviceResource{
        device,
        dev_id,
        handle,
        getInEmbd(meta, weights),
        getOutNorm(meta, weights),
        getOutEmbd(meta, weights),
        getSinTable(meta),
        getCosTable(meta),
        std::move(w_attn_norm),
        std::move(w_attn_qkv),
        std::move(b_attn_qkv),
        std::move(w_attn_out),
        std::move(w_ffn_norm),
        std::move(w_ffn_gate_up),
        std::move(w_ffn_down),
        stream,
        comm,
        memory_pool,
    };
    RUN_INFINI(infinirtDeviceSynchronize());
}

// Tears down everything held by `res`: waits for the device to go idle,
// drops all tensor references, then destroys the operator handle, the stream
// and the communicator, nulling each field afterwards.
void releaseDeviceResource(DeviceResource &res) {
    infinirtDeviceSynchronize();

    // Stand-alone tensors.
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();

    // Per-layer tensors: clear() destroys every element, which releases the
    // reference each one holds.
    res.w_attn_norm.clear();
    res.w_attn_qkv.clear();
    res.b_attn_qkv.clear();
    res.w_attn_out.clear();
    res.w_ffn_norm.clear();
    res.w_ffn_gate_up.clear();
    res.w_ffn_down.clear();

    // Runtime objects.
    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;

    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;

    infinicclCommDestroy(res.comm);
    res.comm = nullptr;
}

// ===== Batched inference on a single device =====
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
PanZezhong's avatar
init  
PanZezhong committed
114
115
116
117
                      uint32_t idev, uint32_t ndev,
                      const uint32_t *tokens, uint32_t ntok,
                      const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                      struct KVCache **kv_caches,
Pan Zezhong's avatar
Pan Zezhong committed
118
119
                      const float *temperature, const uint32_t *topk, const float *topp,
                      uint32_t *output) {
PanZezhong's avatar
init  
PanZezhong committed
120
121
122
    auto nlayer = meta.nlayer;
    auto nkvh = meta.nkvh / ndev;
    auto nh = meta.nh / ndev;
123
    auto ngroup = nh / nkvh;
PanZezhong's avatar
init  
PanZezhong committed
124
125
126
127
128
129
130
    // auto dctx = meta.dctx;
    auto dh = meta.dh;
    auto d = meta.d;
    auto dt_logits = meta.dt_logits;
    auto di = meta.di / ndev;
    auto dvoc = meta.dvoc;
    auto stream = rsrc.stream;
PanZezhong's avatar
PanZezhong committed
131
    bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
PanZezhong's avatar
init  
PanZezhong committed
132
133

    // Allocate buffers
thatPepe's avatar
thatPepe committed
134
135
136
137
138
139
140
    auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
    auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool);
    auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
    auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
    auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
Pan Zezhong's avatar
Pan Zezhong committed
141
142
    auto result_cpu = std::vector<int64_t>(nreq);

PanZezhong's avatar
init  
PanZezhong committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
    // Prepare inputs
    auto batch_pos_ids = std::vector<uint32_t>(ntok);
    size_t req_start = 0;
    for (uint32_t req = 0; req < nreq; req++) {
        for (uint32_t i = 0; i < req_lens[req]; i++) {
            batch_pos_ids[req_start + i] = req_pos[req] + i;
        }
        req_start += req_lens[req];
    }

    std::shared_ptr<Tensor> pos_ids_buf;
    if (rsrc.device == INFINI_DEVICE_CPU) {
        pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
    } else {
thatPepe's avatar
thatPepe committed
157
        pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
PanZezhong's avatar
init  
PanZezhong committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
        RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
                                       INFINIRT_MEMCPY_H2D, stream));
    }
    for (uint32_t i = 0; i < ntok; i++) {
        RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
                                       rsrc.w_in_embd->data(tokens[i] * d),
                                       dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
    }

    // Prepare operators and workspace
    size_t workspace_size = 0, temp_size = 0;
    // attn & mlp rmsnorm
    infiniopRMSNormDescriptor_t desc_norm;
    RUN_INFINI(infiniopCreateRMSNormDescriptor(
PanZezhong's avatar
PanZezhong committed
172
173
        rsrc.handle, &desc_norm, logits_in->desc(),
        logits_out->desc(), rsrc.w_attn_norm[0]->desc(),
PanZezhong's avatar
init  
PanZezhong committed
174
175
176
177
178
        meta.epsilon));
    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm, &workspace_size));
    workspace_size = std::max(workspace_size, temp_size);
    // Attention
    infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o;
PanZezhong's avatar
PanZezhong committed
179
180
181
    infiniopRearrangeDescriptor_t desc_qkv_bias;
    if (has_qkv_bias) {
        RUN_INFINI(infiniopCreateRearrangeDescriptor(
PanZezhong's avatar
PanZezhong committed
182
183
            rsrc.handle, &desc_qkv_bias, qkv_buf->desc(),
            TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->desc()));
PanZezhong's avatar
PanZezhong committed
184
    }
PanZezhong's avatar
init  
PanZezhong committed
185
    RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
186
187
        rsrc.handle, &desc_attn_qkv, qkv_buf->desc(),
        logits_in->desc(), rsrc.w_attn_qkv[0]->desc()));
PanZezhong's avatar
init  
PanZezhong committed
188
    RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
189
190
        rsrc.handle, &desc_attn_o, logits_in->desc(),
        o_buf->desc(), rsrc.w_attn_out[0]->desc()));
PanZezhong's avatar
init  
PanZezhong committed
191
192
193
194
195
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_qkv, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    infiniopRoPEDescriptor_t desc_rope_q, desc_rope_k;
PanZezhong's avatar
PanZezhong committed
196
    qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
PanZezhong's avatar
init  
PanZezhong committed
197
198
199
    auto qkv_buf_q = qkv_buf->slice(1, 0, nh);
    auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh);
    RUN_INFINI(infiniopCreateRoPEDescriptor(
PanZezhong's avatar
PanZezhong committed
200
201
202
        rsrc.handle, &desc_rope_q, qkv_buf_q->desc(), qkv_buf_q->desc(),
        pos_ids_buf->desc(), rsrc.sin_table->desc(),
        rsrc.cos_table->desc()));
PanZezhong's avatar
init  
PanZezhong committed
203
204
205
    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_q, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    RUN_INFINI(infiniopCreateRoPEDescriptor(
PanZezhong's avatar
PanZezhong committed
206
207
208
        rsrc.handle, &desc_rope_k, qkv_buf_k->desc(), qkv_buf_k->desc(),
        pos_ids_buf->desc(), rsrc.sin_table->desc(),
        rsrc.cos_table->desc()));
PanZezhong's avatar
init  
PanZezhong committed
209
210
211
    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_k, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    // attention inner
212
213
214
215
216
217
    auto desc_kv_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
    auto desc_q_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
    auto desc_qk_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
    auto desc_qk_softmaxs = std::vector<infiniopCausalSoftmaxDescriptor_t>(nreq);
    auto desc_attn_v_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
    auto desc_attn_v_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
PanZezhong's avatar
init  
PanZezhong committed
218
    size_t token_offset = 0;
219
220
    size_t max_qk_size = 0;
    size_t max_seq_len = 0;
PanZezhong's avatar
PanZezhong committed
221
    o_buf->dimSplit(1, {nh, dh});
PanZezhong's avatar
init  
PanZezhong committed
222
223
224
    for (uint32_t req = 0; req < nreq; req++) {
        auto past_len = req_pos[req];
        auto seq_len = req_lens[req];
225
        auto total_len = past_len + seq_len;
PanZezhong's avatar
init  
PanZezhong committed
226
        auto o = o_buf->slice({{0, token_offset, seq_len}});
227
228
229
230
231
232
233
234
235
        auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
        auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
        // auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
        // kv cache tensors can share the same descriptor
        // [nkvh, dh, total_len]
        auto full_kv = kv_caches[req]->k[idev][0]->slice(0, 0, total_len)->permute({1, 2, 0});
        auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);

        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
PanZezhong's avatar
PanZezhong committed
236
                                                     cache_kv->desc(), k->desc()));
237
238
239
240
241
242

        // [nkvh, ngroup, seq_len, dh]
        q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
        auto q_t = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
        // [seq_len, nkvh, ngroup, dh] -> [nkvh, ngroup, seq_len, dh]
        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_q_rearranges[req],
PanZezhong's avatar
PanZezhong committed
243
                                                     q_t->desc(), q->desc()));
244
245
246
247
        // [nkvh, ngroup, seq_len, dh] -> [seq_len, nkvh, ngroup, dh]
        auto attn_v_t = q_t;
        auto attn_v = TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3});
        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_attn_v_rearranges[req],
PanZezhong's avatar
PanZezhong committed
248
                                                     attn_v->desc(), attn_v_t->desc()));
249
250
251
252
253
        q_t = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
        auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
        max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
        max_seq_len = std::max(max_seq_len, size_t(seq_len));
        RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
254
            rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), full_kv->desc()));
255
256
257
258
259
260
        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
        workspace_size = std::max(workspace_size, temp_size);

        // [nkvh, total_len, dh]
        auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
        RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
261
            rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), full_v->desc()));
262
263
264
265
266
        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
        workspace_size = std::max(workspace_size, temp_size);

        qk = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
        RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
PanZezhong's avatar
PanZezhong committed
267
            rsrc.handle, &desc_qk_softmaxs[req], qk->desc(), qk->desc()));
268
        RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc_qk_softmaxs[req], &temp_size));
PanZezhong's avatar
init  
PanZezhong committed
269
        workspace_size = std::max(workspace_size, temp_size);
270

PanZezhong's avatar
init  
PanZezhong committed
271
272
        token_offset += seq_len;
    }
thatPepe's avatar
thatPepe committed
273
274
275
    auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
    auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
    auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, dh}, rsrc.memory_pool);
PanZezhong's avatar
init  
PanZezhong committed
276
277
278
279
280

    // MLP descriptors
    infiniopGemmDescriptor_t desc_ffn_gate_up, desc_ffn_down;
    infiniopSwiGLUDescriptor_t desc_swiglu;
    RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
281
282
        rsrc.handle, &desc_ffn_gate_up, gate_up_buf->desc(),
        logits_out->desc(), rsrc.w_ffn_gate_up[0]->desc()));
PanZezhong's avatar
init  
PanZezhong committed
283
284
285
286
287
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_gate_up, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    auto gate_buf = gate_up_buf->slice(1, 0, di);
    auto up_buf = gate_up_buf->slice(1, di, di);
    RUN_INFINI(infiniopCreateSwiGLUDescriptor(
PanZezhong's avatar
PanZezhong committed
288
        rsrc.handle, &desc_swiglu, gate_buf->desc(), up_buf->desc(), gate_buf->desc()));
PanZezhong's avatar
init  
PanZezhong committed
289
290
291
    RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc_swiglu, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
292
293
        rsrc.handle, &desc_ffn_down, logits_in->desc(),
        gate_buf->desc(), rsrc.w_ffn_down[0]->desc()));
PanZezhong's avatar
init  
PanZezhong committed
294
295
296
297
298
299
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_down, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);

    // Output and sample
    infiniopRMSNormDescriptor_t desc_norm_out;
    RUN_INFINI(infiniopCreateRMSNormDescriptor(
PanZezhong's avatar
PanZezhong committed
300
301
302
        rsrc.handle, &desc_norm_out, logits_out->slice(0, 0, 1)->desc(),
        logits_out->slice(0, 0, 1)->desc(),
        rsrc.w_out_norm->desc(), meta.epsilon));
PanZezhong's avatar
init  
PanZezhong committed
303
304
305
306
    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm_out, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    infiniopGemmDescriptor_t desc_out_embd;
    RUN_INFINI(infiniopCreateGemmDescriptor(
PanZezhong's avatar
PanZezhong committed
307
308
309
        rsrc.handle, &desc_out_embd, prob_buf->desc(),
        logits_out->slice(0, 0, nreq)->desc(),
        rsrc.w_out_embd->desc()));
PanZezhong's avatar
init  
PanZezhong committed
310
311
312
313
314
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_out_embd, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    infiniopRandomSampleDescriptor_t desc_sample;
    RUN_INFINI(infiniopCreateRandomSampleDescriptor(
        rsrc.handle, &desc_sample,
PanZezhong's avatar
PanZezhong committed
315
316
        TensorDesc::create(INFINI_DTYPE_I64, {}, {})->desc(),
        TensorDesc::create(dt_logits, {dvoc}, {1})->desc()));
PanZezhong's avatar
init  
PanZezhong committed
317
318
319
    RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
    workspace_size = std::max(workspace_size, temp_size);
    // Allocate workspace
thatPepe's avatar
thatPepe committed
320
    std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size, rsrc.memory_pool);
PanZezhong's avatar
PanZezhong committed
321
    void *workspace = workspace_storage->memory();
PanZezhong's avatar
PanZezhong committed
322

PanZezhong's avatar
PanZezhong committed
323
    // Compute
PanZezhong's avatar
init  
PanZezhong committed
324
325
326
327
328
329
330
331
    for (uint32_t layer = 0; layer < nlayer; layer++) {
        // 1. Attention
        // rms norm
        RUN_INFINI(infiniopRMSNorm(
            desc_norm, workspace, workspace_size,
            logits_out->data(), logits_in->data(),
            rsrc.w_attn_norm[layer]->data(), stream));
        // qkv_proj
PanZezhong's avatar
PanZezhong committed
332
333
334
        if (has_qkv_bias) {
            RUN_INFINI(infiniopRearrange(
                desc_qkv_bias,
Pan Zezhong's avatar
Pan Zezhong committed
335
                qkv_buf->data(), rsrc.b_attn_qkv[layer]->data(), stream));
PanZezhong's avatar
PanZezhong committed
336
        }
PanZezhong's avatar
init  
PanZezhong committed
337
338
339
        RUN_INFINI(infiniopGemm(
            desc_attn_qkv, workspace, workspace_size,
            qkv_buf->data(), logits_out->data(),
Pan Zezhong's avatar
Pan Zezhong committed
340
            rsrc.w_attn_qkv[layer]->data(), 1.0, has_qkv_bias ? 1.0 : 0.0, stream));
PanZezhong's avatar
init  
PanZezhong committed
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
        // rope
        RUN_INFINI(infiniopRoPE(
            desc_rope_q, workspace, workspace_size,
            qkv_buf->data(), qkv_buf->data(),
            pos_ids_buf->data(),
            rsrc.sin_table->data(),
            rsrc.cos_table->data(), stream));
        RUN_INFINI(infiniopRoPE(
            desc_rope_k, workspace, workspace_size,
            qkv_buf->data(nh * dh), qkv_buf->data(nh * dh),
            pos_ids_buf->data(),
            rsrc.sin_table->data(),
            rsrc.cos_table->data(),
            stream));

        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
358
            auto past_len = req_pos[req];
PanZezhong's avatar
init  
PanZezhong committed
359
            auto seq_len = req_lens[req];
360
361
362
363
            auto o = o_buf->slice({{0, token_offset, seq_len}});
            auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
            auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
            auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
PanZezhong's avatar
init  
PanZezhong committed
364
            // self attention
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
            // concat
            RUN_INFINI(infiniopRearrange(
                desc_kv_rearranges[req],
                kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
                k->data(), stream));
            RUN_INFINI(infiniopRearrange(
                desc_kv_rearranges[req],
                kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
                v->data(), stream));
            // qk
            RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
            RUN_INFINI(infiniopGemm(
                desc_qk_gemms[req], workspace, workspace_size,
                qk_buf->data(), rearrange_q_buf->data(), kv_caches[req]->k[idev][layer]->data(), 1. / sqrt(dh), 0.0, stream));
            // softmax
            RUN_INFINI(infiniopCausalSoftmax(
                desc_qk_softmaxs[req], workspace, workspace_size,
                qk_buf->data(), qk_buf->data(), stream));
            // attn val
            RUN_INFINI(infiniopGemm(
                desc_attn_v_gemms[req], workspace, workspace_size,
                attn_val_buf->data(), qk_buf->data(), kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0, stream));
            // rearrange attn val
            RUN_INFINI(infiniopRearrange(
                desc_attn_v_rearranges[req],
390
                o->data(),
391
                attn_val_buf->data(), stream));
PanZezhong's avatar
init  
PanZezhong committed
392
393
394
395
396
397
398
399
400
401
402
403
404
405

            token_offset += seq_len;
        }
        // o_proj
        RUN_INFINI(infiniopGemm(
            desc_attn_o, workspace, workspace_size,
            logits_in->data(), o_buf->data(),
            rsrc.w_attn_out[layer]->data(), 1.0, idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
406
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
407
408
409
410
411
412
413
414
415
416
417
418
419
        }
        // 2. FFN
        // rms_norm
        RUN_INFINI(infiniopRMSNorm(
            desc_norm, workspace, workspace_size,
            logits_out->data(), logits_in->data(),
            rsrc.w_ffn_norm[layer]->data(), stream));
        RUN_INFINI(infiniopGemm(
            desc_ffn_gate_up, workspace, workspace_size,
            gate_up_buf->data(), logits_out->data(), rsrc.w_ffn_gate_up[layer]->data(),
            1.0, 0.0, stream));
        RUN_INFINI(infiniopSwiGLU(
            desc_swiglu, workspace, workspace_size,
PanZezhong's avatar
PanZezhong committed
420
            gate_buf->data(), up_buf->data(), gate_buf->data(), stream));
PanZezhong's avatar
init  
PanZezhong committed
421
422
        RUN_INFINI(infiniopGemm(
            desc_ffn_down, workspace, workspace_size,
PanZezhong's avatar
PanZezhong committed
423
            logits_in->data(), gate_buf->data(),
PanZezhong's avatar
init  
PanZezhong committed
424
425
426
427
428
429
430
            rsrc.w_ffn_down[layer]->data(), 1.0, idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
431
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
        }
    }
    // Sample and Output
    if (idev == 0) {
        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            auto seq_len = req_lens[req];
            token_offset += seq_len;
            RUN_INFINI(infiniopRMSNorm(
                desc_norm_out, workspace, workspace_size,
                logits_out->data(req * d),
                logits_in->data((token_offset - 1) * d),
                rsrc.w_out_norm->data(), stream));
        }
        RUN_INFINI(infiniopGemm(
            desc_out_embd, workspace, workspace_size,
            prob_buf->data(), logits_out->data(),
            rsrc.w_out_embd->data(), 1.0, 0.0, stream));
        std::random_device _rd;
        std::mt19937 gen(_rd());
        token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            auto seq_len = req_lens[req];
            float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
PanZezhong's avatar
PanZezhong committed
456
            // prob_buf->debug();
PanZezhong's avatar
init  
PanZezhong committed
457
458
459
            RUN_INFINI(infiniopRandomSample(
                desc_sample, workspace, workspace_size,
                result_buf->data(req),
Pan Zezhong's avatar
Pan Zezhong committed
460
461
462
463
                prob_buf->data(req * dvoc),
                random_val,
                topp[req], topk[req], temperature[req],
                stream));
PanZezhong's avatar
PanZezhong committed
464
            // result_buf->debug();
PanZezhong's avatar
init  
PanZezhong committed
465
466
467
            token_offset += seq_len;
        }
        RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
PanZezhong committed
468
        RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
Pan Zezhong's avatar
Pan Zezhong committed
469
                                  sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
PanZezhong's avatar
init  
PanZezhong committed
470
        for (uint32_t req = 0; req < nreq; req++) {
Pan Zezhong's avatar
Pan Zezhong committed
471
            output[req] = result_cpu[req];
PanZezhong's avatar
init  
PanZezhong committed
472
473
474
475
476
        }
    }

    // Clean up
    infiniopDestroyRMSNormDescriptor(desc_norm);
477
478
479
    if (has_qkv_bias) {
        infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
    }
PanZezhong's avatar
init  
PanZezhong committed
480
481
482
483
484
    infiniopDestroyGemmDescriptor(desc_attn_qkv);
    infiniopDestroyGemmDescriptor(desc_attn_o);
    infiniopDestroyRoPEDescriptor(desc_rope_q);
    infiniopDestroyRoPEDescriptor(desc_rope_k);
    for (uint32_t req = 0; req < nreq; req++) {
485
486
487
488
489
490
        infiniopDestroyRearrangeDescriptor(desc_kv_rearranges[req]);
        infiniopDestroyRearrangeDescriptor(desc_q_rearranges[req]);
        infiniopDestroyGemmDescriptor(desc_qk_gemms[req]);
        infiniopDestroyCausalSoftmaxDescriptor(desc_qk_softmaxs[req]);
        infiniopDestroyGemmDescriptor(desc_attn_v_gemms[req]);
        infiniopDestroyRearrangeDescriptor(desc_attn_v_rearranges[req]);
PanZezhong's avatar
init  
PanZezhong committed
491
    }
PanZezhong's avatar
PanZezhong committed
492
493
494
    infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
    infiniopDestroySwiGLUDescriptor(desc_swiglu);
    infiniopDestroyGemmDescriptor(desc_ffn_down);
PanZezhong's avatar
init  
PanZezhong committed
495
496
497
498
499
500
501
502
503
504
    infiniopDestroyRMSNormDescriptor(desc_norm_out);
    infiniopDestroyGemmDescriptor(desc_out_embd);
    infiniopDestroyRandomSampleDescriptor(desc_sample);
}

// Runs one batch across all devices of `model` and blocks until every
// device thread has finished.
//
// The arguments are published through `model->req`, each device thread is
// woken by setting its `proceed` flag, and the caller then waits (last device
// first) for each thread to clear that flag again.
__C void
inferBatch(struct JiugeModel *model,
           const uint32_t *tokens, uint32_t ntok,
           const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
           struct KVCache **kv_caches,
           const float *temperature, const uint32_t *topk, const float *topp,
           uint32_t *output) {
    // Publish the request where the device threads read it.
    auto &request = model->req;
    request.tokens = tokens;
    request.ntok = ntok;
    request.req_lens = req_lens;
    request.nreq = nreq;
    request.req_pos = req_pos;
    request.kv_caches = kv_caches;
    request.output = output;
    request.temperature = temperature;
    request.topk = topk;
    request.topp = topp;

    // Kick off every device thread.
    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
        auto &state = model->states[idev];
        {
            std::lock_guard<std::mutex> guard(state.mtx);
            state.proceed = true;
        }
        state.cv_start.notify_one();
    }

    // Wait for completion, highest device index first.
    for (size_t remaining = model->dev_ids.size(); remaining > 0; remaining--) {
        auto &state = model->states[remaining - 1];
        std::unique_lock<std::mutex> lock(state.mtx);
        state.cv_done.wait(lock, [&state] { return !state.proceed; });
    }
}

void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
                  infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
PanZezhong's avatar
PanZezhong committed
534
    // Create Device Resource
PanZezhong's avatar
init  
PanZezhong committed
535
    createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
PanZezhong's avatar
PanZezhong committed
536
537
538
539
540
541
542
543
    {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.loaded = true;
        lock.unlock();
        state.cv_load.notify_one();
    }

    // Infer Loop
PanZezhong's avatar
init  
PanZezhong committed
544
545
    while (true) {
        std::unique_lock<std::mutex> lock(state.mtx);
PanZezhong's avatar
PanZezhong committed
546
        state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
PanZezhong's avatar
PanZezhong committed
547
        // quit if exit_flag is set
PanZezhong's avatar
init  
PanZezhong committed
548
549
550
551
        if (state.exit_flag) {
            break;
        }

Pan Zezhong's avatar
Pan Zezhong committed
552
        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk, req.topp, req.output);
PanZezhong's avatar
init  
PanZezhong committed
553
554
555

        state.proceed = false;
        lock.unlock();
PanZezhong's avatar
PanZezhong committed
556
        state.cv_done.notify_one();
PanZezhong's avatar
init  
PanZezhong committed
557
558
    }

PanZezhong's avatar
PanZezhong committed
559
560
    // Clean-Up
    releaseDeviceResource(*rsrc);
PanZezhong's avatar
init  
PanZezhong committed
561
562
}

// ===== Model lifecycle =====
// Constructs the model: copies the meta data, initializes the runtime,
// creates one communicator per device when more than one device is used,
// starts one launchDevice worker thread per device, and returns only after
// every worker has reported that its resources are loaded.
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
    const int ndev = static_cast<int>(device_ids.size());

    device = device_;
    dev_ids = device_ids;
    // Assigned as freshly built vectors (not resized in place).
    dev_resources = std::vector<DeviceResource>(ndev);
    states = std::vector<InferState>(ndev);
    threads.resize(ndev);

    RUN_INFINI(infinirtInit());

    // Communicators stay null for the single-device case.
    std::vector<infinicclComm_t> comms(ndev, nullptr);
    if (ndev > 1) {
        RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data()));
    }

    // One worker thread per device.
    for (int i = 0; i < ndev; i++) {
        threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]);
    }

    // Block until every worker has finished loading.
    for (auto &state : states) {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.cv_load.wait(lock, [&state] { return state.loaded; });
    }
}

// C entry point: copies the `ndev` device ids from `dev_ids` into a vector
// and returns a heap-allocated JiugeModel built from them.
// The returned pointer is released by destroyJiugeModel.
__C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta,
                 const JiugeWeights *weights,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids) {
    std::vector<int> device_ids(ndev);
    for (int i = 0; i < ndev; i++) {
        device_ids[i] = dev_ids[i];
    }
    return new JiugeModel(meta, weights, device, device_ids);
}

// C entry point: stops every device worker thread and frees the model.
//
// Sets `exit_flag` on each device state and wakes its worker, joins all
// worker threads, then deletes `model`. Passing nullptr is a no-op.
__C void destroyJiugeModel(struct JiugeModel *model) {
    // Accept nullptr like `delete`/`free` do, instead of dereferencing it.
    if (model == nullptr) {
        return;
    }

    const auto ndev = model->dev_resources.size();

    // Ask every worker to leave its inference loop.
    for (size_t idev = 0; idev < ndev; idev++) {
        auto &state = model->states[idev];
        {
            std::lock_guard<std::mutex> guard(state.mtx);
            state.exit_flag = true;
        }
        state.cv_start.notify_one();
    }

    // Wait for the workers; joining a non-joinable std::thread throws.
    for (size_t idev = 0; idev < ndev; idev++) {
        if (model->threads[idev].joinable()) {
            model->threads[idev].join();
        }
    }

    delete model;
}