jiuge_impl.hpp 1.58 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
#ifndef JIUGE_IMPL_H
#define JIUGE_IMPL_H

#include "infinicore_infer.h"

PanZezhong's avatar
PanZezhong committed
6
#include "../../allocator.hpp"
PanZezhong's avatar
init  
PanZezhong committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include "../../tensor.hpp"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

struct DeviceResource {
    // Device
    infiniDevice_t device;
    int device_id;
    infiniopHandle_t handle;
    // Weights
    std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
        cos_table;
    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    // Streams
    infinirtStream_t stream;
PanZezhong's avatar
PanZezhong committed
27
    // Communicator
PanZezhong's avatar
init  
PanZezhong committed
28
    infinicclComm_t comm;
PanZezhong's avatar
PanZezhong committed
29

thatPepe's avatar
thatPepe committed
30
    std::shared_ptr<MemoryPool> memory_pool;
PanZezhong's avatar
init  
PanZezhong committed
31
32
33
34
};

// Synchronization state: one mutex, three condition variables and three
// boolean flags, all of which start out false.
//
// NOTE(review): the flags are presumably read and written under `mtx` and
// signalled through the matching condition variable; the exact handshake is
// implemented elsewhere — confirm against the worker loop.
//
// Because it holds a std::mutex and std::condition_variable, this type is
// neither copyable nor movable.
struct InferState {
    std::mutex mtx;

    std::condition_variable cv_load;  // paired with `loaded` by name
    std::condition_variable cv_start; // paired with `proceed` (presumably)
    std::condition_variable cv_done;  // presumably signals completion

    bool loaded = false;
    bool proceed = false;
    bool exit_flag = false;
};

// Parameters of one inference call, stored as raw non-owning pointers plus
// scalar settings.
//
// Every member has a default initializer so that a default-constructed
// request holds defined values (null pointers, zero counts) instead of
// indeterminate ones. The zeros are placeholders, not a meaningful sampling
// configuration. Field order is unchanged, so positional aggregate
// initialization keeps working.
//
// NOTE(review): the field meanings below are inferred from names only —
// confirm against the inference entry point.
struct InferRequest {
    const uint32_t *tokens = nullptr;     // input token ids (presumably `ntok` entries)
    uint32_t ntok = 0;                    // token count
    const uint32_t *req_lens = nullptr;   // presumably `nreq` per-request lengths
    uint32_t nreq = 0;                    // request count
    const uint32_t *req_pos = nullptr;    // presumably `nreq` per-request positions
    struct KVCache **kv_caches = nullptr; // presumably `nreq` cache pointers
    uint32_t *ans = nullptr;              // output buffer written by the model
    float temperature = 0.0f;             // sampling temperature
    uint32_t topk = 0;                    // top-k sampling setting
    float topp = 0.0f;                    // top-p sampling setting
};

// Model object: metadata, device selection, per-device resources and
// synchronization state, worker threads, and the current request.
struct JiugeModel {
    // Model metadata, stored by value.
    JiugeMeta meta;
    // Device kind.
    infiniDevice_t device;
    // Device ids (presumably the ones handed to the constructor — confirm).
    std::vector<int> dev_ids;
    // Resource bundles (presumably one per entry of dev_ids — confirm).
    std::vector<DeviceResource> dev_resources;
    // Synchronization states (presumably one per worker thread — confirm).
    std::vector<InferState> states;
    // Thread objects owned by the model.
    std::vector<std::thread> threads;
    // Request storage.
    InferRequest req;

    // Defined out of line; takes the metadata and weights by pointer, the
    // device kind, and the list of device ids by value.
    JiugeModel(const JiugeMeta *meta,
               const JiugeWeights *weights,
               infiniDevice_t device,
               std::vector<int> device_ids);
};

// Key and value tensors stored as two identically-shaped nested lists of
// shared tensor pointers.
// NOTE(review): the meaning of the two nesting levels is not visible here
// (presumably device, then layer) — confirm against the cache constructor.
struct KVCache {
    std::vector<std::vector<std::shared_ptr<Tensor>>> k;
    std::vector<std::vector<std::shared_ptr<Tensor>>> v;
};

#endif