#pragma once

#include "../cache_manager/opcache_manager.hpp"

#include <cassert>

// Bundle of per-thread state needed to launch infiniop kernels: the backend
// op handle, memory pool, op-descriptor cache and execution stream, plus a
// lazily managed scratch workspace. Member functions are declarations of
// kernel launch wrappers implemented in the corresponding .cpp.
struct InferenceContext {
    infiniopHandle_t op_handle;                 // backend handle used when creating op descriptors
    std::shared_ptr<MemoryPool> memory_pool;    // allocator for tensors and the workspace
    CacheManager *cache_manager;                // non-owning; caches op descriptors (ownership lives with the caller)
    infinirtStream_t stream;                    // stream on which all ops below are enqueued
    std::shared_ptr<Storage> workspace_storage; // backing buffer for the shared scratch workspace
    size_t current_workspace_size = 0;          // size in bytes of workspace_storage

    InferenceContext(infiniopHandle_t op_handle, std::shared_ptr<MemoryPool> memory_pool, CacheManager *cache_manager, infinirtStream_t stream);

    // Make the shared workspace at least required_size bytes
    // (presumably reallocates only when growing — confirm in the .cpp).
    void ensure_workspace(size_t required_size);

    // Element-wise addition: c = a + b.
    void add(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);

    // RMS normalization of x, scaled by weight w; epsilon guards the divide.
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);

    // General matrix multiply: c = alpha * (a @ b) + beta * c.
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
              float alpha, float beta);

    // Copy src into dst, adapting layout/strides as needed.
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);

    // Rotary position embedding on q and k using precomputed sin/cos tables
    // at positions pos; algo selects GPT-J vs GPT-NeoX rotation layout.
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos,
              infiniopRoPEAlgo_t algo);

    // Softmax of x into y with a causal (lower-triangular) mask.
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);

    // MoE top-k router: selects topk experts per token of x, writing routing
    // weights to `values` and expert ids to `indices`; correction_bias and
    // routed_scaling_factor adjust the scores (exact formula in the .cpp).
    void topkrouter(std::shared_ptr<Tensor> values,  // F32
                    std::shared_ptr<Tensor> indices, // I32
                    std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> correction_bias, // F32
                    float routed_scaling_factor,
                    size_t topk);

    // SwiGLU activation combining the up and gate projections into out.
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);
    // Token sampling from prob with top-p / top-k filtering at the given
    // temperature; random_val is presumably a pre-drawn uniform sample
    // driving the draw — confirm in the .cpp.
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
                      float random_val, float top_p, uint32_t top_k, float temperature);

    // Linear layer: c = alpha * (a @ b) + beta * c, with optional residual
    // and bias terms (presumably skipped when null — confirm in the .cpp).
    void linear(std::shared_ptr<Tensor> c,
                std::shared_ptr<Tensor> a,
                std::shared_ptr<Tensor> b,
                float alpha, float beta,
                std::shared_ptr<Tensor> residual,
                std::shared_ptr<Tensor> bias);

    // Dequantize packed weights in_w using scales in_s and zero-points in_z
    // into the full-precision tensor weight.
    void dequant(std::shared_ptr<Tensor> weight,
                 std::shared_ptr<Tensor> in_w,
                 std::shared_ptr<Tensor> in_s,
                 std::shared_ptr<Tensor> in_z);
};

// Per-thread "current context" slot.
//
// NOTE(fix): the previous version kept `tls_inference_context` in an
// anonymous namespace inside this header. Anonymous-namespace entities have
// internal linkage, so every translation unit that included the header got
// its OWN thread_local pointer — the `inline` accessors below then referenced
// a different variable in each TU (an ODR violation), and a context installed
// from one TU was invisible to code in another. Routing all access through a
// single inline function with a function-local `thread_local` guarantees
// exactly one slot per thread, program-wide.
inline InferenceContext *&inferenceContextSlot() {
    thread_local InferenceContext *ctx = nullptr;
    return ctx;
}

/// Returns the context previously installed on this thread.
/// Asserts (in debug builds) that setInferenceContext() was called first.
inline InferenceContext &getInferenceContext() {
    assert(inferenceContextSlot() != nullptr && "InferenceContext not set for this thread");
    return *inferenceContextSlot();
}

/// Installs `ctx` as this thread's current context (pass nullptr to clear).
inline void setInferenceContext(InferenceContext *ctx) {
    inferenceContextSlot() = ctx;
}

/// Element-wise addition through the thread-local context: c = a + b.
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    auto &ctx = getInferenceContext();
    ctx.add(c, a, b);
}

/// RMS-normalizes x with weight w into y through the thread-local context.
inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> w, float epsilon) {
    auto &ctx = getInferenceContext();
    ctx.rmsnorm(y, x, w, epsilon);
}

/// Matrix multiply through the thread-local context:
/// c = alpha * (a @ b) + beta * c.
inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                 std::shared_ptr<Tensor> b, float alpha, float beta) {
    auto &ctx = getInferenceContext();
    ctx.gemm(c, a, b, alpha, beta);
}

/// Layout-converting copy of src into dst through the thread-local context.
inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    auto &ctx = getInferenceContext();
    ctx.rearrange(dst, src);
}

inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos) {
PanZezhong1725's avatar
PanZezhong1725 committed
101
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_J);
102
103
}

inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                    std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                    std::shared_ptr<Tensor> cos) {
PanZezhong1725's avatar
PanZezhong1725 committed
107
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_NEOX);
blkmjsian's avatar
blkmjsian committed
108
109
}

/// Causally masked softmax of x into y through the thread-local context.
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    auto &ctx = getInferenceContext();
    ctx.causalSoftmax(y, x);
}

/// MoE top-k routing through the thread-local context.
/// `values` and `correction_bias` are F32 tensors; `indices` is I32.
inline void topkrouter(std::shared_ptr<Tensor> values,
                       std::shared_ptr<Tensor> indices,
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<Tensor> correction_bias,
                       float routed_scaling_factor,
                       size_t topk) {
    auto &ctx = getInferenceContext();
    ctx.topkrouter(values, indices, x, correction_bias, routed_scaling_factor, topk);
}

/// SwiGLU activation through the thread-local context, combining the
/// up and gate projections into out.
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                   std::shared_ptr<Tensor> gate) {
    auto &ctx = getInferenceContext();
    ctx.swiglu(out, up, gate);
}

/// Token sampling from prob through the thread-local context, applying
/// top-p / top-k filtering at the given temperature; random_val drives the draw.
inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    auto &ctx = getInferenceContext();
    ctx.randomSample(out, prob, random_val, top_p, top_k, temperature);
}

/// Linear layer through the thread-local context:
/// c = alpha * (a @ b) + beta * c, with optional residual/bias terms.
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                   std::shared_ptr<Tensor> b, float alpha, float beta,
                   std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    auto &ctx = getInferenceContext();
    ctx.linear(c, a, b, alpha, beta, residual, bias);
}

/// Dequantize-then-linear through the thread-local context: materializes the
/// full-precision weight w from the packed triple (w_w, w_s, w_z) in a pooled
/// scratch buffer, then computes out = alpha * (x @ w) + beta * out with
/// optional residual/bias terms.
inline void dequant_linear(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> x,
                           std::shared_ptr<Tensor> w_w, std::shared_ptr<Tensor> w_s, std::shared_ptr<Tensor> w_z,
                           float alpha, float beta, std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    auto &ctx = getInferenceContext();
    const auto in_features = x->shape()[1];
    const auto out_features = out->shape()[1];
    // Scratch weight tensor matching x's dtype; freed back to the pool with w.
    auto w = Tensor::buffer(x->dtype(), {in_features, out_features}, ctx.memory_pool);
    ctx.dequant(w, w_w, w_s, w_z);
    ctx.linear(out, x, w, alpha, beta, residual, bias);
}