#pragma once

#include "../cache_manager/opcache_manager.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
struct InferenceContext {
blkmjsian's avatar
blkmjsian committed
8
9
    infiniopHandle_t op_handle;
    std::shared_ptr<MemoryPool> memory_pool;
wooway777's avatar
wooway777 committed
10
11
12
13
14
    CacheManager *cache_manager;
    infinirtStream_t stream;
    std::shared_ptr<Storage> workspace_storage;
    size_t current_workspace_size = 0;

blkmjsian's avatar
blkmjsian committed
15
    InferenceContext(infiniopHandle_t op_handle, std::shared_ptr<MemoryPool> memory_pool, CacheManager *cache_manager, infinirtStream_t stream);
wooway777's avatar
wooway777 committed
16
17

    void ensure_workspace(size_t required_size);
18

19
20
21
    void add(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);
wooway777's avatar
wooway777 committed
22
23
24
25
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);
26
27
28
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
wooway777's avatar
wooway777 committed
29
              float alpha, float beta);
30
31
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);
wooway777's avatar
wooway777 committed
32
33
34
35
36
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos);
blkmjsian's avatar
blkmjsian committed
37
38
39
40
41
    void rope_v2(std::shared_ptr<Tensor> q,
                 std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos,
                 std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos);
42
43
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);
blkmjsian's avatar
blkmjsian committed
44
45
46
47
48
49
50
51

    void topkrouter(std::shared_ptr<Tensor> values,  // F32
                    std::shared_ptr<Tensor> indices, // I32
                    std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> correction_bias, // F32
                    float routed_scaling_factor,
                    size_t topk);

52
53
54
55
56
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
wooway777's avatar
wooway777 committed
57
                      float random_val, float top_p, uint32_t top_k, float temperature);
58
59
60
61
62

    void linear(std::shared_ptr<Tensor> c,
                std::shared_ptr<Tensor> a,
                std::shared_ptr<Tensor> b,
                float alpha, float beta,
63
64
                std::shared_ptr<Tensor> residual,
                std::shared_ptr<Tensor> bias);
blkmjsian's avatar
blkmjsian committed
65
66
67
68
    void dequant(std::shared_ptr<Tensor> weight,
                 std::shared_ptr<Tensor> in_w,
                 std::shared_ptr<Tensor> in_s,
                 std::shared_ptr<Tensor> in_z);
wooway777's avatar
wooway777 committed
69
};

// Thread-local pointer to the active context for the current thread.
// Must be an `inline` variable, NOT an unnamed namespace: an unnamed
// namespace in a header gives every translation unit its own copy, so a
// context set via setInferenceContext() in one TU would be invisible to
// code in another TU — and the inline accessor functions below would each
// reference a different variable, an ODR violation.
inline thread_local InferenceContext *tls_inference_context = nullptr;

/// Returns the context bound to the current thread; asserts if none was set.
inline InferenceContext &getInferenceContext() {
    auto *const ctx = tls_inference_context;
    assert(ctx != nullptr && "InferenceContext not set for this thread");
    return *ctx;
}

/// Binds `ctx` as the current thread's active context (pass nullptr to clear).
inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}

/// Forwards to InferenceContext::add on the thread-local context.
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    // Move the by-value handles into the member call: avoids an extra
    // atomic refcount increment/decrement per argument.
    getInferenceContext().add(std::move(c), std::move(a), std::move(b));
}

/// Forwards to InferenceContext::rmsnorm on the thread-local context.
inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> w, float epsilon) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().rmsnorm(std::move(y), std::move(x), std::move(w), epsilon);
}

/// Forwards to InferenceContext::gemm on the thread-local context.
inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                 std::shared_ptr<Tensor> b, float alpha, float beta) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().gemm(std::move(c), std::move(a), std::move(b), alpha, beta);
}

/// Forwards to InferenceContext::rearrange on the thread-local context.
inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().rearrange(std::move(dst), std::move(src));
}

/// Forwards to InferenceContext::rope on the thread-local context.
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().rope(std::move(q), std::move(k), std::move(pos),
                               std::move(sin), std::move(cos));
}

/// Forwards to InferenceContext::rope_v2 on the thread-local context.
inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                    std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                    std::shared_ptr<Tensor> cos) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().rope_v2(std::move(q), std::move(k), std::move(pos),
                                  std::move(sin), std::move(cos));
}

/// Forwards to InferenceContext::causalSoftmax on the thread-local context.
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().causalSoftmax(std::move(y), std::move(x));
}

/// Forwards to InferenceContext::topkrouter on the thread-local context.
inline void topkrouter(std::shared_ptr<Tensor> values,  // F32
                       std::shared_ptr<Tensor> indices, // I32
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<Tensor> correction_bias, // F32
                       float routed_scaling_factor,
                       size_t topk) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().topkrouter(std::move(values), std::move(indices),
                                     std::move(x), std::move(correction_bias),
                                     routed_scaling_factor, topk);
}

/// Forwards to InferenceContext::swiglu on the thread-local context.
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                   std::shared_ptr<Tensor> gate) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().swiglu(std::move(out), std::move(up), std::move(gate));
}

/// Forwards to InferenceContext::randomSample on the thread-local context.
inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().randomSample(std::move(out), std::move(prob),
                                       random_val, top_p, top_k, temperature);
}

/// Forwards to InferenceContext::linear on the thread-local context.
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                   std::shared_ptr<Tensor> b, float alpha, float beta,
                   std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    // std::move avoids redundant shared_ptr refcount traffic on forwarding.
    getInferenceContext().linear(std::move(c), std::move(a), std::move(b),
                                 alpha, beta, std::move(residual), std::move(bias));
}

/// Dequantizes the packed weight (w_w with scales w_s and zero points w_z)
/// into a scratch tensor, then runs InferenceContext::linear with it.
inline void dequant_linear(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> x,
                           std::shared_ptr<Tensor> w_w, std::shared_ptr<Tensor> w_s, std::shared_ptr<Tensor> w_z,
                           float alpha, float beta, std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    // Look up the thread-local context once instead of three times
    // (each lookup is a TLS load plus an assert).
    auto &ctx = getInferenceContext();
    // Scratch buffer for the dequantized weight, shaped
    // {x.shape[1], out.shape[1]} — presumably [in_features, out_features];
    // TODO confirm against the dequant kernel's expected layout.
    auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]}, ctx.memory_pool);
    // `w` is still needed by linear() below, so it is passed by copy here;
    // the single-use packed-weight handles are moved.
    ctx.dequant(w, std::move(w_w), std::move(w_s), std::move(w_z));
    ctx.linear(std::move(out), std::move(x), std::move(w), alpha, beta,
               std::move(residual), std::move(bias));
}