// jiuge_impl.hpp
#ifndef JIUGE_IMPL_H
#define JIUGE_IMPL_H

#include "infinicore_infer.h"

#include "../../allocator.hpp"
#include "../../tensor.hpp"

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

struct DeviceResource {
    // Device
    infiniDevice_t device;
    int device_id;
    infiniopHandle_t handle;
    // Weights
    std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
        cos_table;
    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    // Streams
    infinirtStream_t stream;
PanZezhong's avatar
PanZezhong committed
27
    // Communicator
PanZezhong's avatar
init  
PanZezhong committed
28
    infinicclComm_t comm;
PanZezhong's avatar
PanZezhong committed
29

thatPepe's avatar
thatPepe committed
30
    std::shared_ptr<MemoryPool> memory_pool;
PanZezhong's avatar
init  
PanZezhong committed
31
32
33
34
};

// Synchronization state for one device worker.
// NOTE(review): the pairing of flags and condition variables is inferred from
// the names (presumably loaded <-> cv_load, proceed <-> cv_start, and cv_done
// signals completion of a step). Confirm this against the worker loop. The
// flags are plain bools, so any cross-thread access must happen while holding
// `mtx`.
//
// Because this struct holds a std::mutex it is neither copyable nor movable.
// A std::vector<InferState> must therefore be sized at construction (or
// replaced wholesale) and never grown with push_back/resize.
struct InferState {
    std::mutex mtx;
    std::condition_variable cv_load, cv_start, cv_done;
    // All flags start false.
    bool loaded = false;
    bool proceed = false;
    bool exit_flag = false;
};

// Arguments of one batched inference call, stored so they can be read later
// (e.g. by the worker threads of a JiugeModel).
//
// Every pointer is borrowed: this struct has no destructor and never frees
// them, so the caller must keep the buffers alive for as long as the request
// is in use.
// NOTE(review): the layouts below are inferred from the names. Presumably
// `tokens` holds `ntok` tokens concatenated across `nreq` requests; the
// per-request arrays (req_lens, req_pos, kv_caches, temperature, topk, topp,
// output) each have `nreq` entries; and the element type behind `logits`
// depends on the model dtype. Confirm against the caller.
//
// Members default to null/zero so a default-constructed request (such as a
// member that has not yet been filled) holds no indeterminate pointers. The
// member order is unchanged, so positional aggregate initialization still
// works (C++14+).
struct InferRequest {
    const uint32_t *tokens = nullptr;
    uint32_t ntok = 0;
    const uint32_t *req_lens = nullptr;
    uint32_t nreq = 0;
    const uint32_t *req_pos = nullptr;
    // Elaborated specifier: KVCache is only completed later in this header.
    struct KVCache **kv_caches = nullptr;
    // Sampling parameters.
    const float *temperature = nullptr;
    const uint32_t *topk = nullptr;
    const float *topp = nullptr;
    uint32_t *output = nullptr;
    void *logits = nullptr;
};

// Top-level object for a loaded Jiuge model.
// NOTE(review): dev_resources, states and threads look like parallel per-device
// arrays (one entry per element of dev_ids). Confirm this against the
// constructor definition.
struct JiugeModel {
    JiugeMeta meta;  // model metadata, stored by value
    infiniDevice_t device;
    std::vector<int> dev_ids;
    std::vector<DeviceResource> dev_resources;
    // InferState holds a std::mutex and cannot be moved, so this vector must be
    // sized up front rather than grown element by element.
    std::vector<InferState> states;
    std::vector<std::thread> threads;
    // The request currently being served.
    InferRequest req;

    JiugeModel(const JiugeMeta *meta,
               const JiugeWeights *weights,
               infiniDevice_t device,
               std::vector<int> device_ids);
};

// Key/value cache storage. Both members share the same two-level nested layout.
// NOTE(review): the two levels are presumably [device][layer] — confirm against
// the code that allocates and fills them.
struct KVCache {
    std::vector<std::vector<std::shared_ptr<Tensor>>> k;  // key cache tensors
    std::vector<std::vector<std::shared_ptr<Tensor>>> v;  // value cache tensors
};

#endif