// jiuge_impl.hpp
#ifndef JIUGE_IMPL_H
#define JIUGE_IMPL_H

#include "infinicore_infer.h"

#include "../../allocator.hpp"
#include "../../tensor.hpp"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

struct JiugeDeviceResource {
PanZezhong's avatar
init  
PanZezhong committed
16
17
18
19
20
21
22
    // Device
    infiniDevice_t device;
    int device_id;
    infiniopHandle_t handle;
    // Weights
    std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
        cos_table;
23
    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm,w_attn_out,
PanZezhong's avatar
init  
PanZezhong committed
24
25
26
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    // Streams
    infinirtStream_t stream;
PanZezhong's avatar
PanZezhong committed
27
    // Communicator
PanZezhong's avatar
init  
PanZezhong committed
28
    infinicclComm_t comm;
PanZezhong's avatar
PanZezhong committed
29

thatPepe's avatar
thatPepe committed
30
    std::shared_ptr<MemoryPool> memory_pool;
PanZezhong's avatar
init  
PanZezhong committed
31
32
33
34
};

/// Synchronization state for one inference worker.
///
/// NOTE(review): JiugeModel keeps a vector of these next to its worker
/// threads, so presumably there is one per device worker — confirm.
/// `mtx` is presumably the mutex paired with the three condition variables
/// and guarding the flags below — confirm against the implementation.
///
/// Because it contains a std::mutex and std::condition_variable members,
/// this type is neither copyable nor movable.
struct InferState {
    std::mutex mtx;
    std::condition_variable cv_load, cv_start, cv_done;
    bool loaded = false;    // all flags start cleared
    bool proceed = false;
    bool exit_flag = false;
};

/// One batched inference request, passed around as a bundle of raw pointers.
///
/// The struct has no destructor and never frees anything: it does not own
/// the buffers it points to, so the caller must keep them alive while the
/// request is in use. Every member defaults to null/zero so an unset request
/// is detectable instead of holding indeterminate pointers. Member order is
/// significant for aggregate initialization — do not reorder.
struct InferRequest {
    const uint32_t *tokens = nullptr;     // `ntok` token ids (presumably all requests concatenated — confirm)
    uint32_t ntok = 0;
    const uint32_t *req_lens = nullptr;   // presumably `nreq` per-request lengths — confirm
    uint32_t nreq = 0;
    const uint32_t *req_pos = nullptr;    // presumably per-request start positions — confirm
    struct KVCache **kv_caches = nullptr; // presumably one KV cache per request — confirm
    const float *temperature = nullptr;   // sampling parameters; presumably one value per request — confirm
    const uint32_t *topk = nullptr;
    const float *topp = nullptr;
    uint32_t *output = nullptr;           // non-const: presumably receives sampled tokens — confirm
    void *logits = nullptr;               // element type not fixed by this header (void*)
};

struct JiugeModel {
    JiugeMeta meta;
    infiniDevice_t device;
    std::vector<int> dev_ids;
blkmjsian's avatar
blkmjsian committed
59
    std::vector<JiugeDeviceResource> dev_resources;
PanZezhong's avatar
init  
PanZezhong committed
60
61
62
63
64
65
66
    std::vector<InferState> states;
    std::vector<std::thread> threads;
    InferRequest req;

    JiugeModel(const JiugeMeta *, const JiugeWeights *, infiniDevice_t device, std::vector<int> device_ids);
};

#include "../../cache.hpp"

#endif