inference_context.cpp 8.24 KB
Newer Older
wooway777's avatar
wooway777 committed
1
2
3
4
5
6
7
8
#include "inference_context.hpp"

#include <functional>

#include "../tensor.hpp"
#include "../utils.hpp"

InferenceContext::InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream)
    : rsrc(rsrc), cache_manager(cache_manager), stream(stream) {}

void InferenceContext::ensure_workspace(size_t required_size) {
9
    if (required_size > current_workspace_size || !workspace_storage) {
wooway777's avatar
wooway777 committed
10
11
12
13
14
        workspace_storage = Storage::createFromPool(required_size, rsrc->memory_pool);
        current_workspace_size = required_size;
    }
}

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
void InferenceContext::add(std::shared_ptr<Tensor> c,
                           std::shared_ptr<Tensor> a,
                           std::shared_ptr<Tensor> b) {
    size_t key = CacheManager::createDescriptorKey(c, a, b,
                                                   nullptr, nullptr);

    infiniopAddDescriptor_t desc;
    if (!cache_manager->getAddDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateAddDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
        cache_manager->putAddDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetAddWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopAdd(
        desc, workspace, workspace_size,
        c->data(), a->data(), b->data(), stream));
}

wooway777's avatar
wooway777 committed
37
38
39
40
void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
                               std::shared_ptr<Tensor> x,
                               std::shared_ptr<Tensor> w,
                               float epsilon) {
41
    size_t key = CacheManager::createDescriptorKey(y, x, w, nullptr, nullptr);
wooway777's avatar
wooway777 committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

    infiniopRMSNormDescriptor_t desc;
    if (!cache_manager->getRMSNormDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRMSNormDescriptor(
            rsrc->handle, &desc, y->desc(), x->desc(), w->desc(), epsilon));
        cache_manager->putRMSNormDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRMSNorm(
        desc, workspace, workspace_size,
        y->data(), x->data(), w->data(), stream));
}

60
61
62
void InferenceContext::gemm(std::shared_ptr<Tensor> c,
                            std::shared_ptr<Tensor> a,
                            std::shared_ptr<Tensor> b,
wooway777's avatar
wooway777 committed
63
                            float alpha, float beta) {
64
65
    size_t key = CacheManager::createDescriptorKey(c, a, b,
                                                   nullptr, nullptr);
wooway777's avatar
wooway777 committed
66
67
68

    infiniopGemmDescriptor_t desc;
    if (!cache_manager->getGemmDescriptor(key, desc)) {
69
        RUN_INFINI(infiniopCreateGemmDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
wooway777's avatar
wooway777 committed
70
71
72
73
74
75
76
77
78
79
80
81
82
        cache_manager->putGemmDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopGemm(
        desc, workspace, workspace_size,
        c->data(), a->data(), b->data(), alpha, beta, stream));
}

83
84
85
void InferenceContext::rearrange(std::shared_ptr<Tensor> dst,
                                 std::shared_ptr<Tensor> src) {
    size_t key = CacheManager::createDescriptorKey(dst, src, nullptr, nullptr, nullptr);
wooway777's avatar
wooway777 committed
86
87
88

    infiniopRearrangeDescriptor_t desc;
    if (!cache_manager->getRearrangeDescriptor(key, desc)) {
89
        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc->handle, &desc, dst->desc(), src->desc()));
wooway777's avatar
wooway777 committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
        cache_manager->putRearrangeDescriptor(key, desc);
    }

    RUN_INFINI(infiniopRearrange(
        desc,
        dst->data(),
        src->data(),
        stream));
}

void InferenceContext::rope(std::shared_ptr<Tensor> q,
                            std::shared_ptr<Tensor> k,
                            std::shared_ptr<Tensor> pos,
                            std::shared_ptr<Tensor> sin,
                            std::shared_ptr<Tensor> cos) {
105
    size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
wooway777's avatar
wooway777 committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

    infiniopRoPEDescriptor_t desc;
    if (!cache_manager->getRoPEDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRoPEDescriptor(
            rsrc->handle, &desc, q->desc(), k->desc(),
            pos->desc(), sin->desc(), cos->desc()));
        cache_manager->putRoPEDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRoPE(
        desc, workspace, workspace_size,
        q->data(), k->data(), pos->data(),
        sin->data(), cos->data(), stream));
}

126
127
128
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
                                     std::shared_ptr<Tensor> x) {
    size_t key = CacheManager::createDescriptorKey(y, x, nullptr, nullptr, nullptr);
wooway777's avatar
wooway777 committed
129
130
131
132

    infiniopCausalSoftmaxDescriptor_t desc;
    if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
133
            rsrc->handle, &desc, y->desc(), x->desc()));
wooway777's avatar
wooway777 committed
134
135
136
137
138
139
140
141
142
143
144
145
        cache_manager->putCausalSoftmaxDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopCausalSoftmax(desc, workspace, workspace_size,
                                     y->data(), x->data(), stream));
}

146
147
148
149
void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
                              std::shared_ptr<Tensor> up,
                              std::shared_ptr<Tensor> gate) {
    size_t key = CacheManager::createDescriptorKey(out, up, gate, nullptr, nullptr);
wooway777's avatar
wooway777 committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

    infiniopSwiGLUDescriptor_t desc;
    if (!cache_manager->getSwiGLUDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateSwiGLUDescriptor(
            rsrc->handle, &desc, out->desc(), up->desc(), gate->desc()));
        cache_manager->putSwiGLUDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopSwiGLU(desc, workspace, workspace_size,
                              out->data(), up->data(), gate->data(), stream));
}

167
168
void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
                                    std::shared_ptr<Tensor> prob,
wooway777's avatar
wooway777 committed
169
                                    float random_val, float top_p, uint32_t top_k, float temperature) {
170
    size_t key = CacheManager::createDescriptorKey(out, prob, nullptr, nullptr, nullptr);
wooway777's avatar
wooway777 committed
171
172
173
174

    infiniopRandomSampleDescriptor_t desc;
    if (!cache_manager->getRandomSampleDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRandomSampleDescriptor(
175
            rsrc->handle, &desc, out->desc(), prob->desc()));
wooway777's avatar
wooway777 committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
        cache_manager->putRandomSampleDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRandomSample(
        desc, workspace, workspace_size,
        out->data(), prob->data(),
        random_val, top_p, top_k, temperature,
        stream));
}
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

// Applies gemm(c, a, b, alpha, beta) and, when `residual` is non-null, adds
// the residual to the result in `c`.
//
// Three residual cases are handled:
//   * residual uses a different buffer than `c`: gemm, then add(c, c, residual).
//   * residual shares `c`'s buffer and beta == 0: a single gemm with beta = 1,
//     so the existing contents of `c` act as the residual (assumes gemm
//     accumulates beta * c -- confirm against the infiniop docs).
//   * residual shares `c`'s buffer and beta != 0: `c` is copied to a temporary
//     first, because the gemm overwrites it; the copy is added back afterwards.
void InferenceContext::linear(std::shared_ptr<Tensor> c,
                              std::shared_ptr<Tensor> a,
                              std::shared_ptr<Tensor> b,
                              float alpha, float beta,
                              std::shared_ptr<Tensor> residual) {
    // Plain projection, nothing to add afterwards.
    if (!residual) {
        gemm(c, a, b, alpha, beta);
        return;
    }

    const bool residual_is_output = (residual->data() == c->data());
    if (!residual_is_output) {
        gemm(c, a, b, alpha, beta);
        add(c, c, residual);
        return;
    }

    // From here on the residual lives in the output buffer itself.
    if (beta == 0.0) {
        gemm(c, a, b, alpha, 1.0);
        return;
    }

    // Keep the pre-gemm contents of c so they can be added back afterwards.
    auto saved_output = Tensor::buffer(c->dtype(), c->shape(), rsrc->memory_pool);
    saved_output->copyFrom(c, rsrc->handle, stream);
    gemm(c, a, b, alpha, beta);
    add(c, c, saved_output);
}