// jiuge.cpp: Jiuge model inference - per-device resource setup, batched forward pass, and device worker threads.
#include "jiuge_impl.hpp"
#include "jiuge_weight.hpp"

#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "../inference_context.hpp"
#include "infinicore_infer.h"

#include <algorithm>
#include <cmath>
#include <mutex>
#include <random>
#include <thread>
#include <vector>

// Builds the resource bundle one device worker needs and stores it in `*rsrc`.
//
// Binds the calling thread to `dev_id`, creates an operator handle and a
// runtime stream, collects every layer's weight tensors through the get*
// helpers (which receive `idev`/`ndev`, presumably to pick this rank's
// slice - TODO confirm in jiuge_weight.hpp), and creates a memory pool that
// inferDeviceBatch() later allocates activation buffers from. Blocks until
// the device is idle before returning.
//
// Parameters:
//   rsrc    - output; overwritten with the newly created resources.
//   meta    - model hyper-parameters (only `nlayer` is read directly here).
//   weights - weight descriptors handed to the get* helpers; a QKV bias is
//             loaded only when `weights->attn_qkv_b` is non-null.
//   device  - device type; also stored in the resource.
//   idev    - this worker's rank, in [0, ndev).
//   ndev    - number of ranks.
//   dev_id  - physical device id the calling thread is bound to.
//   comm    - collective communicator for this rank; may be nullptr
//             (single-device runs never create one).
void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
                          const JiugeWeights *weights,
                          infiniDevice_t device, int idev,
                          int ndev, int dev_id,
                          infinicclComm_t comm) {
    // Capacity of the per-device activation memory pool (128 MiB).
    constexpr size_t kMemoryPoolBytes = 128 * 1024 * 1024;

    RUN_INFINI(infinirtSetDevice(device, dev_id));
    // Check creation status like every other runtime call in this file; an
    // unchecked failure here would only surface later as an obscure crash.
    infiniopHandle_t handle;
    RUN_INFINI(infiniopCreateHandle(&handle));
    infinirtStream_t stream;
    RUN_INFINI(infinirtStreamCreate(&stream));

    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    const bool has_qkv_bias = weights->attn_qkv_b != nullptr;
    w_attn_norm.reserve(meta->nlayer);
    w_attn_qkv.reserve(meta->nlayer);
    if (has_qkv_bias) {
        b_attn_qkv.reserve(meta->nlayer);
    }
    w_attn_out.reserve(meta->nlayer);
    w_ffn_norm.reserve(meta->nlayer);
    w_ffn_gate_up.reserve(meta->nlayer);
    w_ffn_down.reserve(meta->nlayer);

    for (size_t layer = 0; layer < meta->nlayer; layer++) {
        w_attn_norm.push_back(getAttnNorm(meta, weights, layer));
        w_attn_qkv.push_back(getAttnQKV(meta, weights, layer, idev, ndev));
        // The bias list stays empty when the model has no QKV bias;
        // inferDeviceBatch() keys off that emptiness.
        if (has_qkv_bias) {
            b_attn_qkv.push_back(getAttnQKVBias(meta, weights, layer, idev, ndev));
        }
        w_attn_out.push_back(getAttnO(meta, weights, layer, idev, ndev));
        w_ffn_norm.push_back(getFFNNorm(meta, weights, layer));
        w_ffn_gate_up.push_back(getFFNGateUp(meta, weights, layer, idev, ndev));
        w_ffn_down.push_back(getFFNDown(meta, weights, layer, idev, ndev));
    }

    auto memory_pool = std::make_shared<MemoryPool>(kMemoryPoolBytes);

    *rsrc = DeviceResource{
        device,
        dev_id,
        handle,
        getInEmbd(meta, weights),
        getOutNorm(meta, weights),
        getOutEmbd(meta, weights),
        getSinTable(meta),
        getCosTable(meta),
        w_attn_norm,
        w_attn_qkv,
        b_attn_qkv,
        w_attn_out,
        w_ffn_norm,
        w_ffn_gate_up,
        w_ffn_down,
        stream,
        comm,
        memory_pool,
    };
    RUN_INFINI(infinirtDeviceSynchronize());
}

// Tears down what createDeviceResource() set up for one device.
//
// Waits for the device to go idle, then:
// - drops this resource's references to the weight tensors and the memory pool;
// - destroys the operator handle and the stream;
// - destroys the communicator, if there is one.
// Teardown is best-effort: status codes from these calls are not checked.
void releaseDeviceResource(DeviceResource &res) {
    infinirtDeviceSynchronize();

    // Release individual Tensors.
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();

    // Per-layer weights: clear() destroys each shared_ptr, dropping its reference.
    res.w_attn_norm.clear();
    res.w_attn_qkv.clear();
    res.b_attn_qkv.clear();
    res.w_attn_out.clear();
    res.w_ffn_norm.clear();
    res.w_ffn_gate_up.clear();
    res.w_ffn_down.clear();

    // Drop the pool reference here, on the thread bound to this device,
    // instead of leaving it alive until the DeviceResource itself is
    // destroyed by whichever thread deletes the model.
    res.memory_pool.reset();

    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;
    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;
    // Single-device models never create a communicator (comm stays null).
    if (res.comm != nullptr) {
        infinicclCommDestroy(res.comm);
        res.comm = nullptr;
    }
}

void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
PanZezhong's avatar
init  
PanZezhong committed
115
116
117
118
                      uint32_t idev, uint32_t ndev,
                      const uint32_t *tokens, uint32_t ntok,
                      const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                      struct KVCache **kv_caches,
Pan Zezhong's avatar
Pan Zezhong committed
119
                      const float *temperature, const uint32_t *topk, const float *topp,
wooway777's avatar
wooway777 committed
120
                      uint32_t *output, InferenceContext &ctx) {
PanZezhong's avatar
init  
PanZezhong committed
121
122
123
    auto nlayer = meta.nlayer;
    auto nkvh = meta.nkvh / ndev;
    auto nh = meta.nh / ndev;
124
    auto ngroup = nh / nkvh;
PanZezhong's avatar
init  
PanZezhong committed
125
126
127
128
129
130
131
    // auto dctx = meta.dctx;
    auto dh = meta.dh;
    auto d = meta.d;
    auto dt_logits = meta.dt_logits;
    auto di = meta.di / ndev;
    auto dvoc = meta.dvoc;
    auto stream = rsrc.stream;
PanZezhong's avatar
PanZezhong committed
132
    bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
PanZezhong's avatar
init  
PanZezhong committed
133
134

    // Allocate buffers
thatPepe's avatar
thatPepe committed
135
136
137
138
139
140
141
    auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
    auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool);
    auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
    auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
    auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
Pan Zezhong's avatar
Pan Zezhong committed
142
143
    auto result_cpu = std::vector<int64_t>(nreq);

PanZezhong's avatar
init  
PanZezhong committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
    // Prepare inputs
    auto batch_pos_ids = std::vector<uint32_t>(ntok);
    size_t req_start = 0;
    for (uint32_t req = 0; req < nreq; req++) {
        for (uint32_t i = 0; i < req_lens[req]; i++) {
            batch_pos_ids[req_start + i] = req_pos[req] + i;
        }
        req_start += req_lens[req];
    }

    std::shared_ptr<Tensor> pos_ids_buf;
    if (rsrc.device == INFINI_DEVICE_CPU) {
        pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
    } else {
thatPepe's avatar
thatPepe committed
158
        pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
PanZezhong's avatar
init  
PanZezhong committed
159
160
161
162
163
164
165
166
167
168
        RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
                                       INFINIRT_MEMCPY_H2D, stream));
    }
    for (uint32_t i = 0; i < ntok; i++) {
        RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
                                       rsrc.w_in_embd->data(tokens[i] * d),
                                       dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
    }

    // Attention
wooway777's avatar
wooway777 committed
169
170
171
172
    auto qkv_desc = TensorDesc::create(dt_logits, qkv_buf->shape(), qkv_buf->strides());
    auto b_attn_qkv_desc = TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1});
    auto o_desc = TensorDesc::create(dt_logits, o_buf->shape(), o_buf->strides());

PanZezhong's avatar
PanZezhong committed
173
    qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
PanZezhong's avatar
init  
PanZezhong committed
174
    // attention inner
175
176
    size_t max_qk_size = 0;
    size_t max_seq_len = 0;
PanZezhong's avatar
PanZezhong committed
177
    o_buf->dimSplit(1, {nh, dh});
wooway777's avatar
wooway777 committed
178

PanZezhong's avatar
init  
PanZezhong committed
179
180
181
    for (uint32_t req = 0; req < nreq; req++) {
        auto past_len = req_pos[req];
        auto seq_len = req_lens[req];
182
        auto total_len = past_len + seq_len;
wooway777's avatar
wooway777 committed
183

184
185
        max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
        max_seq_len = std::max(max_seq_len, size_t(seq_len));
PanZezhong's avatar
init  
PanZezhong committed
186
    }
wooway777's avatar
wooway777 committed
187

thatPepe's avatar
thatPepe committed
188
189
190
    auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
    auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
    auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, dh}, rsrc.memory_pool);
PanZezhong's avatar
init  
PanZezhong committed
191

wooway777's avatar
wooway777 committed
192
    // MLP buffers
PanZezhong's avatar
init  
PanZezhong committed
193
194
195
196
    auto gate_buf = gate_up_buf->slice(1, 0, di);
    auto up_buf = gate_up_buf->slice(1, di, di);

    // Output and sample
wooway777's avatar
wooway777 committed
197
198
    auto result_desc = TensorDesc::create(INFINI_DTYPE_I64, {}, {});
    auto prob_desc = TensorDesc::create(dt_logits, {dvoc}, {1});
PanZezhong's avatar
PanZezhong committed
199

PanZezhong's avatar
PanZezhong committed
200
    // Compute
PanZezhong's avatar
init  
PanZezhong committed
201
202
203
    for (uint32_t layer = 0; layer < nlayer; layer++) {
        // 1. Attention
        // rms norm
wooway777's avatar
wooway777 committed
204
        ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
PanZezhong's avatar
init  
PanZezhong committed
205
        // qkv_proj
PanZezhong's avatar
PanZezhong committed
206
        if (has_qkv_bias) {
wooway777's avatar
wooway777 committed
207
            ctx.rearrange(qkv_buf, qkv_desc, rsrc.b_attn_qkv[layer], b_attn_qkv_desc);
PanZezhong's avatar
PanZezhong committed
208
        }
wooway777's avatar
wooway777 committed
209
210
211
212
        ctx.gemm(qkv_buf, qkv_desc,
                 logits_out, nullptr,
                 rsrc.w_attn_qkv[layer], nullptr,
                 1.0, has_qkv_bias ? 1.0 : 0.0);
PanZezhong's avatar
init  
PanZezhong committed
213
        // rope
wooway777's avatar
wooway777 committed
214
215
        ctx.rope(qkv_buf->slice(1, 0, nh), qkv_buf->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
        ctx.rope(qkv_buf->slice(1, nh, nkvh), qkv_buf->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
PanZezhong's avatar
init  
PanZezhong committed
216
217
218

        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
219
            auto past_len = req_pos[req];
PanZezhong's avatar
init  
PanZezhong committed
220
            auto seq_len = req_lens[req];
wooway777's avatar
wooway777 committed
221
            auto total_len = past_len + seq_len;
222
223
224
225
            auto o = o_buf->slice({{0, token_offset, seq_len}});
            auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
            auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
            auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
wooway777's avatar
wooway777 committed
226
227
228
229

            auto qt_rearrange_desc = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
            auto qt_gemm_desc = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
            auto qk_gemm_desc = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
PanZezhong's avatar
init  
PanZezhong committed
230
            // self attention
231
            // concat
wooway777's avatar
wooway777 committed
232
233
            ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), nullptr, k, nullptr);
            ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), nullptr, v, nullptr);
234
            // qk
wooway777's avatar
wooway777 committed
235
236
237
238
239
240
            ctx.rearrange(rearrange_q_buf, qt_rearrange_desc,
                          q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}), nullptr);
            ctx.gemm(qk_buf, qk_gemm_desc,
                     rearrange_q_buf, qt_gemm_desc,
                     kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}), nullptr,
                     1. / sqrt(dh), 0.0);
241
            // softmax
wooway777's avatar
wooway777 committed
242
243
244
245
246
247
            auto qk_desc = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
            ctx.causalSoftmax(qk_buf, qk_desc, qk_buf, qk_desc);
            ctx.gemm(attn_val_buf, qt_gemm_desc,
                     qk_buf, qk_gemm_desc,
                     kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}), nullptr,
                     1.0, 0.0);
248
            // rearrange attn val
wooway777's avatar
wooway777 committed
249
250
            ctx.rearrange(o, TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3}),
                          attn_val_buf, qt_rearrange_desc);
PanZezhong's avatar
init  
PanZezhong committed
251
252
253
254

            token_offset += seq_len;
        }
        // o_proj
wooway777's avatar
wooway777 committed
255
256
257
258
        ctx.gemm(logits_in, nullptr,
                 o_buf, o_desc,
                 rsrc.w_attn_out[layer], nullptr,
                 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
PanZezhong's avatar
init  
PanZezhong committed
259
260
261
262
263
264

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
265
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
266
267
268
        }
        // 2. FFN
        // rms_norm
wooway777's avatar
wooway777 committed
269
270
271
272
273
274
275
276
277
278
        ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
        ctx.gemm(gate_up_buf, nullptr,
                 logits_out, nullptr,
                 rsrc.w_ffn_gate_up[layer], nullptr,
                 1.0, 0.0);
        ctx.swiglu(gate_buf, up_buf, gate_buf);
        ctx.gemm(logits_in, nullptr,
                 gate_buf, nullptr,
                 rsrc.w_ffn_down[layer], nullptr,
                 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
PanZezhong's avatar
init  
PanZezhong committed
279
280
281
282
283
284

        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
PanZezhong's avatar
PanZezhong committed
285
            RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
init  
PanZezhong committed
286
287
288
289
290
291
292
293
        }
    }
    // Sample and Output
    if (idev == 0) {
        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            auto seq_len = req_lens[req];
            token_offset += seq_len;
wooway777's avatar
wooway777 committed
294
295
296
297
            ctx.rmsnorm(logits_out->slice(0, req, 1),
                        logits_in->slice(0, token_offset - 1, 1),
                        rsrc.w_out_norm,
                        meta.epsilon);
PanZezhong's avatar
init  
PanZezhong committed
298
        }
wooway777's avatar
wooway777 committed
299
300
301
302
        ctx.gemm(prob_buf, nullptr,
                 logits_out->slice(0, 0, nreq), nullptr,
                 rsrc.w_out_embd, nullptr,
                 1.0, 0.0);
PanZezhong's avatar
init  
PanZezhong committed
303
304
305
306
307
308
        std::random_device _rd;
        std::mt19937 gen(_rd());
        token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            auto seq_len = req_lens[req];
            float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
wooway777's avatar
wooway777 committed
309
310
311
            ctx.randomSample(result_buf->slice(0, req, 1), result_desc,
                             prob_buf->slice(0, req, 1), prob_desc,
                             random_val, topp[req], topk[req], temperature[req]);
PanZezhong's avatar
init  
PanZezhong committed
312
313
314
            token_offset += seq_len;
        }
        RUN_INFINI(infinirtStreamSynchronize(stream));
PanZezhong's avatar
PanZezhong committed
315
        RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
Pan Zezhong's avatar
Pan Zezhong committed
316
                                  sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
PanZezhong's avatar
init  
PanZezhong committed
317
        for (uint32_t req = 0; req < nreq; req++) {
Pan Zezhong's avatar
Pan Zezhong committed
318
            output[req] = result_cpu[req];
PanZezhong's avatar
init  
PanZezhong committed
319
320
321
322
323
324
325
326
327
        }
    }
}

// C entry point: runs one batched inference step on every device worker and
// returns once all of them have finished.
//
// The request is stored in the model's shared request slot. Each worker is
// woken by setting its `proceed` flag under its state mutex. The call then
// waits on each worker's `cv_done` until that worker has cleared `proceed`
// again. Workers only read the request after observing `proceed`, so the
// slot is filled before any flag is raised.
__C void
inferBatch(struct JiugeModel *model,
           const uint32_t *tokens, uint32_t ntok,
           const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
           struct KVCache **kv_caches,
           const float *temperature, const uint32_t *topk, const float *topp,
           uint32_t *output) {
    auto &request = model->req;
    request.tokens = tokens;
    request.ntok = ntok;
    request.req_lens = req_lens;
    request.nreq = nreq;
    request.req_pos = req_pos;
    request.kv_caches = kv_caches;
    request.output = output;
    request.temperature = temperature;
    request.topk = topk;
    request.topp = topp;

    const size_t worker_count = model->dev_ids.size();

    // Raise every worker's start flag, notifying after the mutex is released.
    for (size_t w = 0; w < worker_count; ++w) {
        auto &worker = model->states[w];
        {
            std::unique_lock<std::mutex> guard(worker.mtx);
            worker.proceed = true;
        }
        worker.cv_start.notify_one();
    }

    // Wait for completion, visiting workers from the last one to the first.
    for (size_t w = worker_count; w-- > 0;) {
        auto &worker = model->states[w];
        std::unique_lock<std::mutex> guard(worker.mtx);
        worker.cv_done.wait(guard, [&worker] { return !(worker.proceed); });
    }
}

void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
                  infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
wooway777's avatar
wooway777 committed
357
358
359
    CacheManager cache_manager(100);
    InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);

PanZezhong's avatar
PanZezhong committed
360
    // Create Device Resource
PanZezhong's avatar
init  
PanZezhong committed
361
    createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
PanZezhong's avatar
PanZezhong committed
362
363
364
365
366
367
368
369
    {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.loaded = true;
        lock.unlock();
        state.cv_load.notify_one();
    }

    // Infer Loop
PanZezhong's avatar
init  
PanZezhong committed
370
371
    while (true) {
        std::unique_lock<std::mutex> lock(state.mtx);
PanZezhong's avatar
PanZezhong committed
372
        state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
PanZezhong's avatar
PanZezhong committed
373
        // quit if exit_flag is set
PanZezhong's avatar
init  
PanZezhong committed
374
375
376
377
        if (state.exit_flag) {
            break;
        }

wooway777's avatar
wooway777 committed
378
379
380
381
        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
                         req.req_lens, req.nreq, req.req_pos, req.kv_caches,
                         req.temperature, req.topk, req.topp, req.output,
                         ctx);
PanZezhong's avatar
init  
PanZezhong committed
382
383
384

        state.proceed = false;
        lock.unlock();
PanZezhong's avatar
PanZezhong committed
385
        state.cv_done.notify_one();
PanZezhong's avatar
init  
PanZezhong committed
386
387
    }

PanZezhong's avatar
PanZezhong committed
388
389
    // Clean-Up
    releaseDeviceResource(*rsrc);
PanZezhong's avatar
init  
PanZezhong committed
390
391
}

// Starts one worker thread (launchDevice) per entry of `device_ids` and
// blocks until every worker reports that its resources are loaded.
// Multi-device models get one collective communicator per rank; with a
// single device the communicator passed to the worker is nullptr.
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
    const int rank_count = int(device_ids.size());
    device = device_;
    dev_ids = device_ids;
    dev_resources = std::vector<DeviceResource>(rank_count);
    states = std::vector<InferState>(rank_count);
    threads.resize(rank_count);
    RUN_INFINI(infinirtInit());

    std::vector<infinicclComm_t> rank_comms(rank_count, nullptr);
    if (rank_count > 1) {
        RUN_INFINI(infinicclCommInitAll(device, rank_comms.data(), rank_count, dev_ids.data()));
    }

    for (int rank = 0; rank < rank_count; ++rank) {
        threads[rank] = std::thread(launchDevice,
                                    std::cref(meta), weights, &dev_resources[rank],
                                    std::ref(states[rank]), std::ref(req),
                                    device, rank, rank_count, dev_ids[rank], rank_comms[rank]);
    }

    // Wait for each worker's load signal.
    for (int rank = 0; rank < rank_count; ++rank) {
        auto &worker = states[rank];
        std::unique_lock<std::mutex> guard(worker.mtx);
        worker.cv_load.wait(guard, [&worker] { return worker.loaded; });
    }
}

// C entry point: copies the caller's `ndev` device ids into an owned vector
// and constructs a JiugeModel on them. The constructor returns only after
// all device workers have loaded. The caller owns the returned model and
// releases it with destroyJiugeModel().
__C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta,
                 const JiugeWeights *weights,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids) {
    std::vector<int> ids(ndev);
    std::copy_n(dev_ids, ndev, ids.begin());
    return new JiugeModel(meta, weights, device, ids);
}

// C entry point: stops every device worker, waits for the workers to finish
// (each releases its own device resources before returning), then frees the
// model. Passing nullptr is a no-op, matching delete/free().
__C void destroyJiugeModel(struct JiugeModel *model) {
    if (model == nullptr) {
        return;
    }
    auto ndev = model->dev_resources.size();

    // Raise each worker's exit flag, notifying after the mutex is released.
    for (size_t idev = 0; idev < ndev; idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].exit_flag = true;
        lock.unlock();
        model->states[idev].cv_start.notify_one();
    }

    for (size_t idev = 0; idev < ndev; idev++) {
        if (model->threads[idev].joinable()) {
            model->threads[idev].join();
        }
    }

    delete model;
}