#pragma once

#include "../cache_manager/opcache_manager.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
struct InferenceContext {
blkmjsian's avatar
blkmjsian committed
8
9
    infiniopHandle_t op_handle;
    std::shared_ptr<MemoryPool> memory_pool;
wooway777's avatar
wooway777 committed
10
11
12
13
14
    CacheManager *cache_manager;
    infinirtStream_t stream;
    std::shared_ptr<Storage> workspace_storage;
    size_t current_workspace_size = 0;

blkmjsian's avatar
blkmjsian committed
15
    InferenceContext(infiniopHandle_t op_handle, std::shared_ptr<MemoryPool> memory_pool, CacheManager *cache_manager, infinirtStream_t stream);
wooway777's avatar
wooway777 committed
16
17

    void ensure_workspace(size_t required_size);
18

19
20
21
    void add(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);
hejianlin's avatar
hejianlin committed
22
23
24
25
26
27
28
29
    void conv(std::shared_ptr<Tensor> y,
              std::shared_ptr<Tensor> x,
              std::shared_ptr<Tensor> w,
              std::shared_ptr<Tensor> bias,
              void *pads, void *strides, void *dilations, size_t n);
    void mul(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);
wooway777's avatar
wooway777 committed
30
31
32
33
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);
34
35
36
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
wooway777's avatar
wooway777 committed
37
              float alpha, float beta);
38
39
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);
wooway777's avatar
wooway777 committed
40
41
42
43
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
PanZezhong1725's avatar
PanZezhong1725 committed
44
45
              std::shared_ptr<Tensor> cos,
              infiniopRoPEAlgo_t algo);
46
47
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);
blkmjsian's avatar
blkmjsian committed
48
49
50
51
52
53
54
55

    void topkrouter(std::shared_ptr<Tensor> values,  // F32
                    std::shared_ptr<Tensor> indices, // I32
                    std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> correction_bias, // F32
                    float routed_scaling_factor,
                    size_t topk);

56
57
58
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);
hejianlin's avatar
hejianlin committed
59
60
    void silu(std::shared_ptr<Tensor> out,
              std::shared_ptr<Tensor> input);
61
62
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
wooway777's avatar
wooway777 committed
63
                      float random_val, float top_p, uint32_t top_k, float temperature);
64
65
66
67
68

    void linear(std::shared_ptr<Tensor> c,
                std::shared_ptr<Tensor> a,
                std::shared_ptr<Tensor> b,
                float alpha, float beta,
69
70
                std::shared_ptr<Tensor> residual,
                std::shared_ptr<Tensor> bias);
blkmjsian's avatar
blkmjsian committed
71
72
73
74
    void dequant(std::shared_ptr<Tensor> weight,
                 std::shared_ptr<Tensor> in_w,
                 std::shared_ptr<Tensor> in_s,
                 std::shared_ptr<Tensor> in_z);
wooway777's avatar
wooway777 committed
75
};
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

// Current per-thread context. Declared `inline` (C++17) so every translation
// unit that includes this header shares ONE variable per thread. The previous
// unnamed namespace gave each TU its own copy, so the inline accessors below
// were defined differently in every TU (an ODR violation) and a context set
// in one TU was invisible to code in another.
inline thread_local InferenceContext *tls_inference_context = nullptr;

/// Returns the context installed on this thread via setInferenceContext().
/// Asserts (in debug builds) that one has been set.
inline InferenceContext &getInferenceContext() {
    assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
    return *tls_inference_context;
}

/// Installs `ctx` as the calling thread's current context (nullptr to clear).
inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}

/// Element-wise add on the calling thread's context: c = a + b.
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    auto &ctx = getInferenceContext();
    ctx.add(c, a, b);
}

hejianlin's avatar
hejianlin committed
94
inline void conv(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> bias,
PanZezhong's avatar
PanZezhong committed
95
                 void *pads, void *strides, void *dilations, size_t n) {
hejianlin's avatar
hejianlin committed
96
97
98
99
100
101
102
    getInferenceContext().conv(y, x, w, bias, pads, strides, dilations, n);
}

/// Element-wise multiply on the calling thread's context: c = a * b.
inline void mul(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    auto &ctx = getInferenceContext();
    ctx.mul(c, a, b);
}

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> w, float epsilon) {
    getInferenceContext().rmsnorm(y, x, w, epsilon);
}

/// GEMM on the calling thread's context (alpha/beta scaling as in
/// InferenceContext::gemm).
inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                 std::shared_ptr<Tensor> b, float alpha, float beta) {
    auto &ctx = getInferenceContext();
    ctx.gemm(c, a, b, alpha, beta);
}

/// Copies src into dst on the calling thread's context.
inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    auto &ctx = getInferenceContext();
    ctx.rearrange(dst, src);
}

inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos) {
PanZezhong1725's avatar
PanZezhong1725 committed
120
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_J);
121
122
}

blkmjsian's avatar
blkmjsian committed
123
124
125
inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                    std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                    std::shared_ptr<Tensor> cos) {
PanZezhong1725's avatar
PanZezhong1725 committed
126
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_NEOX);
blkmjsian's avatar
blkmjsian committed
127
128
}

129
130
131
132
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    getInferenceContext().causalSoftmax(y, x);
}

blkmjsian's avatar
blkmjsian committed
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
inline void topkrouter(std::shared_ptr<Tensor> values,  // F32
                       std::shared_ptr<Tensor> indices, // I32
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<Tensor> correction_bias, // F32
                       float routed_scaling_factor,
                       size_t topk) {

    getInferenceContext().topkrouter(values,  // F32
                                     indices, // I32
                                     x,
                                     correction_bias, // F32
                                     routed_scaling_factor,
                                     topk);
}

148
149
150
151
152
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                   std::shared_ptr<Tensor> gate) {
    getInferenceContext().swiglu(out, up, gate);
}

hejianlin's avatar
hejianlin committed
153
154
155
156
inline void silu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> input) {
    getInferenceContext().silu(out, input);
}

157
158
159
160
161
162
163
inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}

/// Linear layer on the calling thread's context: GEMM with optional fused
/// residual add and bias (see InferenceContext::linear).
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                   std::shared_ptr<Tensor> b, float alpha, float beta,
                   std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
}
blkmjsian's avatar
blkmjsian committed
167
168
169
170
171
172
173
174

inline void dequant_linear(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> x,
                           std::shared_ptr<Tensor> w_w, std::shared_ptr<Tensor> w_s, std::shared_ptr<Tensor> w_z,
                           float alpha, float beta, std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]}, getInferenceContext().memory_pool);
    getInferenceContext().dequant(w, w_w, w_s, w_z);
    getInferenceContext().linear(out, x, w, alpha, beta, residual, bias);
}