// inference_context.hpp
#pragma once

#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"
// Standard library headers
#include <cassert>

/// Groups a device resource handle, a cache manager, a runtime stream and a
/// workspace buffer, and declares the operator entry points that use them.
///
/// NOTE(review): only declarations are visible in this header. The remarks
/// about what each operator computes are inferred from names and parameter
/// lists and must be confirmed against the out-of-line definitions.
struct InferenceContext {
    /// Device resource handle supplied at construction (ownership is not
    /// expressed by the type; presumably non-owning -- confirm).
    DeviceResource *rsrc;
    /// Cache manager supplied at construction (presumably non-owning -- confirm).
    CacheManager *cache_manager;
    /// Runtime stream handle supplied at construction.
    infinirtStream_t stream;
    /// Shared handle to the workspace storage; default-constructed (empty) here.
    std::shared_ptr<Storage> workspace_storage;
    /// Workspace size bookkeeping, starting at 0.
    /// NOTE(review): units are presumably bytes -- confirm in ensure_workspace().
    size_t current_workspace_size = 0;

    /// @param rsrc           device resource handle
    /// @param cache_manager  cache manager handle
    /// @param stream         runtime stream handle
    InferenceContext(DeviceResource *rsrc,
                     CacheManager *cache_manager,
                     infinirtStream_t stream);

    /// Requests a workspace of at least @p required_size.
    /// NOTE(review): growth/reallocation policy lives in the definition, which
    /// is not visible from this header.
    void ensure_workspace(size_t required_size);

    /// Element-wise add operator; presumably c = a + b -- confirm.
    void add(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);

    /// RMS-norm operator taking output @p y, input @p x, weight @p w and
    /// @p epsilon (presumably the numerical-stability term -- confirm).
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);

    /// GEMM operator; presumably c = alpha * a * b + beta * c -- confirm.
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
              float alpha, float beta);

    /// Rearrange operator from @p src into @p dst.
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);

    /// RoPE operator over @p q / @p k using @p pos and the @p sin / @p cos
    /// tensors.
    /// NOTE(review): which of q/k is the output is not visible here -- confirm.
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos);

    /// Causal-softmax operator with output @p y and input @p x.
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);

    /// SwiGLU operator combining @p up and @p gate into @p out.
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);

    /// Sampling operator writing to @p out from @p prob, parameterised by
    /// @p random_val, @p top_p, @p top_k and @p temperature.
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
                      float random_val, float top_p, uint32_t top_k, float temperature);

    /// Linear operator: like gemm() but with an extra @p residual tensor.
    /// NOTE(review): whether @p residual may be null, and how it is combined
    /// with the product, is defined out of line -- confirm.
    void linear(std::shared_ptr<Tensor> c,
                std::shared_ptr<Tensor> a,
                std::shared_ptr<Tensor> b,
                float alpha, float beta,
                std::shared_ptr<Tensor> residual);
};

// Thread-local context pointer, its accessors, and free-function wrappers
// that forward to the current thread's InferenceContext.

// Per-thread pointer to the active InferenceContext; nullptr until
// setInferenceContext() stores one on that thread.
//
// This is a C++17 inline variable rather than a member of an unnamed
// namespace. In a header, an unnamed namespace gives every translation unit
// its own private copy, so a context set from one .cpp would not be seen by
// the inline accessors compiled into another, and those inline functions
// would violate the one-definition rule by naming a different object in each
// unit. An inline variable is a single entity shared across all units.
inline thread_local InferenceContext *tls_inference_context = nullptr;

/// Returns a reference to the InferenceContext registered for the calling
/// thread. When assertions are enabled, aborts if none has been registered.
inline InferenceContext &getInferenceContext() {
    // The assert expression is kept verbatim so the failure message is unchanged.
    assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
    InferenceContext &active = *tls_inference_context;
    return active;
}

/// Registers @p ctx as the calling thread's InferenceContext.
/// The pointer is stored as given; passing nullptr clears the registration.
inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}

inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    getInferenceContext().add(c, a, b);
}

inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> w, float epsilon) {
    getInferenceContext().rmsnorm(y, x, w, epsilon);
}

inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                 std::shared_ptr<Tensor> b, float alpha, float beta) {
    getInferenceContext().gemm(c, a, b, alpha, beta);
}

inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    getInferenceContext().rearrange(dst, src);
}

inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos) {
    getInferenceContext().rope(q, k, pos, sin, cos);
}

inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    getInferenceContext().causalSoftmax(y, x);
}

inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                   std::shared_ptr<Tensor> gate) {
    getInferenceContext().swiglu(out, up, gate);
}

inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}

inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                   std::shared_ptr<Tensor> b, float alpha, float beta,
                   std::shared_ptr<Tensor> residual) {
    getInferenceContext().linear(c, a, b, alpha, beta, residual);
}