// jiuge.cpp — Jiuge model: per-device resources, batched inference, and device worker threads.
#include "jiuge_impl.hpp"
#include "jiuge_weight.hpp"

#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "../inference_context.hpp"
#include "infinicore_infer.h"

#include <algorithm>
#include <cmath>
#include <random>
#include <thread>
#include <vector>

void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
PanZezhong's avatar
init  
PanZezhong committed
14
15
16
17
18
19
20
21
22
23
                          const JiugeWeights *weights,
                          infiniDevice_t device, int idev,
                          int ndev, int dev_id,
                          infinicclComm_t comm) {
    RUN_INFINI(infinirtSetDevice(device, dev_id));
    infiniopHandle_t handle;
    infiniopCreateHandle(&handle);
    infinirtStream_t stream;
    infinirtStreamCreate(&stream);

24
    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
PanZezhong's avatar
init  
PanZezhong committed
25
26
27
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    for (size_t layer = 0; layer < meta->nlayer; layer++) {
        w_attn_norm.push_back(
PanZezhong's avatar
PanZezhong committed
28
            getAttnNorm(meta, weights, layer));
PanZezhong's avatar
init  
PanZezhong committed
29
        w_attn_qkv.push_back(
PanZezhong's avatar
PanZezhong committed
30
            getAttnQKV(meta, weights, layer, idev, ndev));
PanZezhong's avatar
init  
PanZezhong committed
31
32
        if (weights->attn_qkv_b != nullptr) {
            b_attn_qkv.push_back(
PanZezhong's avatar
PanZezhong committed
33
                getAttnQKVBias(meta, weights, layer, idev, ndev));
PanZezhong's avatar
init  
PanZezhong committed
34
        }
35
36
37
38
39
40
41

        if (weights->attn_q_norm != nullptr) {
            w_attn_q_norm.push_back(
                getAttnQNorm(meta, weights, layer));
            w_attn_k_norm.push_back(
                getAttnKNorm(meta, weights, layer));
        }
PanZezhong's avatar
init  
PanZezhong committed
42
        w_attn_out.push_back(
PanZezhong's avatar
PanZezhong committed
43
            getAttnO(meta, weights, layer, idev, ndev));
PanZezhong's avatar
init  
PanZezhong committed
44
        w_ffn_norm.push_back(
PanZezhong's avatar
PanZezhong committed
45
            getFFNNorm(meta, weights, layer));
PanZezhong's avatar
init  
PanZezhong committed
46
        w_ffn_gate_up.push_back(
PanZezhong's avatar
PanZezhong committed
47
            getFFNGateUp(meta, weights, layer, idev, ndev));
PanZezhong's avatar
init  
PanZezhong committed
48
        w_ffn_down.push_back(
PanZezhong's avatar
PanZezhong committed
49
            getFFNDown(meta, weights, layer, idev, ndev));
PanZezhong's avatar
init  
PanZezhong committed
50
51
    }

thatPepe's avatar
thatPepe committed
52
53
    auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024);

blkmjsian's avatar
blkmjsian committed
54
    *rsrc = JiugeDeviceResource{
PanZezhong's avatar
PanZezhong committed
55
56
57
58
59
60
61
62
63
64
65
        device,
        dev_id,
        handle,
        getInEmbd(meta, weights),
        getOutNorm(meta, weights),
        getOutEmbd(meta, weights),
        getSinTable(meta),
        getCosTable(meta),
        w_attn_norm,
        w_attn_qkv,
        b_attn_qkv,
66
67
        w_attn_q_norm,
        w_attn_k_norm,
PanZezhong's avatar
PanZezhong committed
68
69
70
71
72
73
        w_attn_out,
        w_ffn_norm,
        w_ffn_gate_up,
        w_ffn_down,
        stream,
        comm,
thatPepe's avatar
thatPepe committed
74
        memory_pool,
PanZezhong's avatar
PanZezhong committed
75
    };
PanZezhong's avatar
PanZezhong committed
76
77
78
    RUN_INFINI(infinirtDeviceSynchronize());
}

/// Release everything createDeviceResource allocated for one device.
///
/// Called on the device's worker thread after its inference loop exits.
/// It waits for the device to go idle and drops every weight tensor,
/// including the optional QKV bias and Q/K norm weights. It then destroys
/// the op handle, the stream and, if one exists, the collective
/// communicator. Return codes are ignored on purpose: this is best-effort
/// teardown.
void releaseDeviceResource(JiugeDeviceResource &res) {
    infinirtDeviceSynchronize();

    // Release individual tensors.
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();

    // Drop every tensor in a per-layer vector and empty it.
    auto release_layers = [](auto &tensors) {
        for (auto &t : tensors) {
            t.reset();
        }
        tensors.clear();
    };
    release_layers(res.w_attn_norm);
    release_layers(res.w_attn_qkv);
    release_layers(res.b_attn_qkv);
    // The Q/K norm weights used to be skipped here. They then outlived the
    // handle and stream, and were only freed when the model was deleted.
    release_layers(res.w_attn_q_norm);
    release_layers(res.w_attn_k_norm);
    release_layers(res.w_attn_out);
    release_layers(res.w_ffn_norm);
    release_layers(res.w_ffn_gate_up);
    release_layers(res.w_ffn_down);

    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;
    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;
    // Single-device models never create a communicator (comm stays null).
    if (res.comm != nullptr) {
        infinicclCommDestroy(res.comm);
    }
    res.comm = nullptr;
}

void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
PanZezhong's avatar
init  
PanZezhong committed
124
125
126
127
                      uint32_t idev, uint32_t ndev,
                      const uint32_t *tokens, uint32_t ntok,
                      const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                      struct KVCache **kv_caches,
Pan Zezhong's avatar
Pan Zezhong committed
128
                      const float *temperature, const uint32_t *topk, const float *topp,
PanZezhong's avatar
PanZezhong committed
129
                      uint32_t *output, void *last_logits) {
PanZezhong's avatar
init  
PanZezhong committed
130
131
132
    auto nlayer = meta.nlayer;
    auto nkvh = meta.nkvh / ndev;
    auto nh = meta.nh / ndev;
133
    auto ngroup = nh / nkvh;
PanZezhong's avatar
init  
PanZezhong committed
134
135
136
137
138
139
140
    // auto dctx = meta.dctx;
    auto dh = meta.dh;
    auto d = meta.d;
    auto dt_logits = meta.dt_logits;
    auto di = meta.di / ndev;
    auto dvoc = meta.dvoc;
    auto stream = rsrc.stream;
PanZezhong's avatar
PanZezhong committed
141
    bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
142
    bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;
PanZezhong's avatar
init  
PanZezhong committed
143
144

    // Allocate buffers
thatPepe's avatar
thatPepe committed
145
146
147
148
149
150
151
    auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
    auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool);
    auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
    auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
    auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
Pan Zezhong's avatar
Pan Zezhong committed
152
153
    auto result_cpu = std::vector<int64_t>(nreq);

154
    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
155
156
    auto q_buf = qkv_rope->slice(1, 0, nh);
    auto k_buf = qkv_rope->slice(1, nh, nkvh);
157

PanZezhong's avatar
init  
PanZezhong committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    // Prepare inputs
    auto batch_pos_ids = std::vector<uint32_t>(ntok);
    size_t req_start = 0;
    for (uint32_t req = 0; req < nreq; req++) {
        for (uint32_t i = 0; i < req_lens[req]; i++) {
            batch_pos_ids[req_start + i] = req_pos[req] + i;
        }
        req_start += req_lens[req];
    }

    std::shared_ptr<Tensor> pos_ids_buf;
    if (rsrc.device == INFINI_DEVICE_CPU) {
        pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
    } else {
thatPepe's avatar
thatPepe committed
172
        pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
PanZezhong's avatar
init  
PanZezhong committed
173
174
175
176
177
178
179
180
181
182
183
        RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
                                       INFINIRT_MEMCPY_H2D, stream));
    }
    for (uint32_t i = 0; i < ntok; i++) {
        RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
                                       rsrc.w_in_embd->data(tokens[i] * d),
                                       dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
    }

    // Attention
    // attention inner
184
185
    size_t max_qk_size = 0;
    size_t max_seq_len = 0;
wooway777's avatar
wooway777 committed
186

PanZezhong's avatar
init  
PanZezhong committed
187
188
189
    for (uint32_t req = 0; req < nreq; req++) {
        auto past_len = req_pos[req];
        auto seq_len = req_lens[req];
190
        auto total_len = past_len + seq_len;
wooway777's avatar
wooway777 committed
191

192
193
        max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
        max_seq_len = std::max(max_seq_len, size_t(seq_len));
PanZezhong's avatar
init  
PanZezhong committed
194
    }
wooway777's avatar
wooway777 committed
195

196
    auto qk_buf = Tensor::buffer(dt_logits, {nh * max_qk_size}, rsrc.memory_pool);
thatPepe's avatar
thatPepe committed
197
    auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
198
    auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
199
    auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
200
    auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
PanZezhong's avatar
init  
PanZezhong committed
201

wooway777's avatar
wooway777 committed
202
    // MLP buffers
PanZezhong's avatar
init  
PanZezhong committed
203
204
205
    auto gate_buf = gate_up_buf->slice(1, 0, di);
    auto up_buf = gate_up_buf->slice(1, di, di);

PanZezhong's avatar
PanZezhong committed
206
    // Compute
PanZezhong's avatar
init  
PanZezhong committed
207
208
209
    for (uint32_t layer = 0; layer < nlayer; layer++) {
        // 1. Attention
        // rms norm
210
        rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
PanZezhong's avatar
init  
PanZezhong committed
211
        // qkv_proj
212
        linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
213
214
215
216
217
218

        if (has_qk_norm) {
            rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
            rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
        }

PanZezhong's avatar
init  
PanZezhong committed
219
        // rope
220
221
        rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
        rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
PanZezhong's avatar
init  
PanZezhong committed
222
223
224

        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
225
            auto past_len = req_pos[req];
PanZezhong's avatar
init  
PanZezhong committed
226
            auto seq_len = req_lens[req];
wooway777's avatar
wooway777 committed
227
            auto total_len = past_len + seq_len;
228
            auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
229
            auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
230
231
            auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
            auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
wooway777's avatar
wooway777 committed
232

PanZezhong's avatar
init  
PanZezhong committed
233
            // self attention
234
            // concat
235
236
            rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
            rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
237
            // qk
238
            rearrange(q_rearrange->slice(2, 0, seq_len), q);
239
            auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
240
            auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
241
            linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
242
            // softmax
243
            auto qk_softmax = qk_gemm->view({nh, seq_len, total_len});
244
            causalSoftmax(qk_softmax, qk_softmax);
245
            auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
246
            linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
247
            // rearrange attn val
248
            rearrange(o, attn_val_gemm->slice(2, 0, seq_len));
PanZezhong's avatar
init  
PanZezhong committed
249
250
251

            token_offset += seq_len;
        }
252

PanZezhong's avatar
init  
PanZezhong committed
253
        // o_proj
254
        linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
PanZezhong's avatar
init  
PanZezhong committed
255
256
257
258
259
260

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
261
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
262
263
        }
        // 2. FFN
264
        rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
265
        linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr, nullptr);
266
        swiglu(gate_buf, up_buf, gate_buf);
267
        linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
PanZezhong's avatar
init  
PanZezhong committed
268
269
270
271
272
273

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
274
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
275
276
277
278
        }
    }
    // Sample and Output
    if (idev == 0) {
PanZezhong's avatar
PanZezhong committed
279
280
281
282
283
284
        if (last_logits != nullptr) {
            rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
            auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
            linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
            RUN_INFINI(infinirtStreamSynchronize(stream));
            RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
PanZezhong's avatar
init  
PanZezhong committed
285
        }
PanZezhong's avatar
PanZezhong committed
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
        if (output != nullptr) {
            size_t token_offset = 0;
            for (uint32_t req = 0; req < nreq; req++) {
                auto seq_len = req_lens[req];
                token_offset += seq_len;
                rmsnorm(logits_out->slice(0, req, 1),
                        logits_in->slice(0, token_offset - 1, 1),
                        rsrc.w_out_norm,
                        meta.epsilon);
            }
            linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
            std::random_device _rd;
            std::mt19937 gen(_rd());
            token_offset = 0;
            for (uint32_t req = 0; req < nreq; req++) {
                auto seq_len = req_lens[req];
                float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
303
304
                randomSample(result_buf->slice(0, req, 1)->view_as({}, {}),
                             prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}),
PanZezhong's avatar
PanZezhong committed
305
306
307
308
309
310
311
312
313
                             random_val, topp[req], topk[req], temperature[req]);
                token_offset += seq_len;
            }
            RUN_INFINI(infinirtStreamSynchronize(stream));
            RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
                                      sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
            for (uint32_t req = 0; req < nreq; req++) {
                output[req] = uint32_t(result_cpu[req]);
            }
PanZezhong's avatar
init  
PanZezhong committed
314
315
316
317
        }
    }
}

__INFINI_C void
blkmjsian's avatar
blkmjsian committed
319
inferBatchJiuge(struct JiugeModel *model,
320
321
322
323
324
                const uint32_t *tokens, uint32_t ntok,
                const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                struct KVCache **kv_caches,
                const float *temperature, const uint32_t *topk, const float *topp,
                uint32_t *output) {
PanZezhong's avatar
init  
PanZezhong committed
325
326
327
328
329
330
    model->req.tokens = tokens;
    model->req.ntok = ntok;
    model->req.req_lens = req_lens;
    model->req.nreq = nreq;
    model->req.req_pos = req_pos;
    model->req.kv_caches = kv_caches;
Pan Zezhong's avatar
Pan Zezhong committed
331
    model->req.output = output;
PanZezhong's avatar
PanZezhong committed
332
    model->req.logits = nullptr;
PanZezhong's avatar
init  
PanZezhong committed
333
334
335
336
337
338
339
340
    model->req.temperature = temperature;
    model->req.topk = topk;
    model->req.topp = topp;

    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].proceed = true;
        lock.unlock();
PanZezhong's avatar
PanZezhong committed
341
342
343
344
345
346
347
        model->states[idev].cv_start.notify_one();
    }
    for (size_t i = model->dev_ids.size(); i > 0; i--) {
        auto idev = i - 1;
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
        lock.unlock();
PanZezhong's avatar
init  
PanZezhong committed
348
349
350
    }
}

351
__INFINI_C void
blkmjsian's avatar
blkmjsian committed
352
forwardBatchJiuge(struct JiugeModel *model,
353
354
355
356
                  const uint32_t *tokens, uint32_t ntok,
                  const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                  struct KVCache **kv_caches,
                  void *logits) {
PanZezhong's avatar
PanZezhong committed
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    model->req.tokens = tokens;
    model->req.ntok = ntok;
    model->req.req_lens = req_lens;
    model->req.nreq = nreq;
    model->req.req_pos = req_pos;
    model->req.kv_caches = kv_caches;
    model->req.output = nullptr;
    model->req.logits = logits;
    model->req.temperature = nullptr;
    model->req.topk = nullptr;
    model->req.topp = nullptr;

    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].proceed = true;
        lock.unlock();
        model->states[idev].cv_start.notify_one();
    }
    for (size_t i = model->dev_ids.size(); i > 0; i--) {
        auto idev = i - 1;
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
        lock.unlock();
    }
}

void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDeviceResource *rsrc, InferState &state, InferRequest &req,
PanZezhong's avatar
init  
PanZezhong committed
384
                  infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
blkmjsian's avatar
blkmjsian committed
385
386
387
    // Create Device Resource
    createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);

388
    CacheManager cache_manager(100);
blkmjsian's avatar
blkmjsian committed
389
    InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream);
wooway777's avatar
wooway777 committed
390

391
392
393
    // Set the inference context for this thread
    setInferenceContext(&ctx);

PanZezhong's avatar
PanZezhong committed
394
395
396
397
398
399
400
401
    {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.loaded = true;
        lock.unlock();
        state.cv_load.notify_one();
    }

    // Infer Loop
PanZezhong's avatar
init  
PanZezhong committed
402
403
    while (true) {
        std::unique_lock<std::mutex> lock(state.mtx);
PanZezhong's avatar
PanZezhong committed
404
        state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
PanZezhong's avatar
PanZezhong committed
405
        // quit if exit_flag is set
PanZezhong's avatar
init  
PanZezhong committed
406
407
408
409
        if (state.exit_flag) {
            break;
        }

wooway777's avatar
wooway777 committed
410
411
        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
                         req.req_lens, req.nreq, req.req_pos, req.kv_caches,
PanZezhong's avatar
PanZezhong committed
412
                         req.temperature, req.topk, req.topp, req.output, req.logits);
PanZezhong's avatar
init  
PanZezhong committed
413
414
415

        state.proceed = false;
        lock.unlock();
PanZezhong's avatar
PanZezhong committed
416
        state.cv_done.notify_one();
PanZezhong's avatar
init  
PanZezhong committed
417
418
    }

PanZezhong's avatar
PanZezhong committed
419
420
    // Clean-Up
    releaseDeviceResource(*rsrc);
421
    setInferenceContext(nullptr); // Clear the context when done
PanZezhong's avatar
init  
PanZezhong committed
422
423
}

/// Create a model sharded across `device_ids` and start one worker per device.
///
/// A communicator group is created only when more than one device is used;
/// otherwise every rank gets a null communicator. Each worker runs
/// launchDevice on its own thread. The constructor returns only after every
/// worker has reported `loaded`.
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
    const int ndev = static_cast<int>(device_ids.size());
    device = device_;
    dev_ids = device_ids;
    dev_resources = std::vector<JiugeDeviceResource>(ndev);
    states = std::vector<InferState>(ndev);
    threads.resize(ndev);
    RUN_INFINI(infinirtInit());

    std::vector<infinicclComm_t> comms(ndev, nullptr);
    if (ndev > 1) {
        RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data()));
    }

    for (int rank = 0; rank < ndev; ++rank) {
        threads[rank] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[rank],
                                    std::ref(states[rank]), std::ref(req), device, rank, ndev,
                                    dev_ids[rank], comms[rank]);
    }

    // Block until every worker has finished loading its resources.
    for (int rank = 0; rank < ndev; ++rank) {
        InferState &st = states[rank];
        std::unique_lock<std::mutex> guard(st.mtx);
        st.cv_load.wait(guard, [&st] { return st.loaded; });
    }
}

/// C entry point: build a JiugeModel on `ndev` devices listed in `dev_ids`.
///
/// The device-id array is copied, so the caller keeps ownership of it. The
/// returned model must be released with destroyJiugeModel.
__INFINI_C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta,
                 const JiugeWeights *weights,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids) {
    std::vector<int> ids(ndev);
    std::copy_n(dev_ids, ndev, ids.begin());
    return new JiugeModel(meta, weights, device, ids);
}

/// C entry point: stop every device worker and free the model.
///
/// It sets each worker's `exit_flag` and wakes it. Each worker then releases
/// its own device resources (see launchDevice). The function joins all worker
/// threads before deleting `model`. Passing null is a no-op.
__INFINI_C void destroyJiugeModel(struct JiugeModel *model) {
    // Tolerate null like `delete` does, instead of dereferencing it below.
    if (model == nullptr) {
        return;
    }
    auto ndev = model->dev_resources.size();

    for (size_t idev = 0; idev < ndev; idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].exit_flag = true;
        lock.unlock();
        model->states[idev].cv_start.notify_one();
    }

    for (size_t idev = 0; idev < ndev; idev++) {
        // join() on a non-joinable thread throws std::system_error.
        if (model->threads[idev].joinable()) {
            model->threads[idev].join();
        }
    }

    delete model;
}