// inference_context.cpp
#include "inference_context.hpp"
#include "../tensor.hpp"
#include "../utils.hpp"

blkmjsian's avatar
blkmjsian committed
5
6
InferenceContext::InferenceContext(infiniopHandle_t op_handle_, std::shared_ptr<MemoryPool> memory_pool_, CacheManager *cache_manager, infinirtStream_t stream)
    : op_handle(op_handle_), memory_pool(memory_pool_), cache_manager(cache_manager), stream(stream) {}
wooway777's avatar
wooway777 committed
7
8

void InferenceContext::ensure_workspace(size_t required_size) {
9
    if (required_size > current_workspace_size || !workspace_storage) {
blkmjsian's avatar
blkmjsian committed
10
        workspace_storage = Storage::createFromPool(required_size, memory_pool);
wooway777's avatar
wooway777 committed
11
12
13
14
        current_workspace_size = required_size;
    }
}

15
16
17
void InferenceContext::add(std::shared_ptr<Tensor> c,
                           std::shared_ptr<Tensor> a,
                           std::shared_ptr<Tensor> b) {
18
    size_t key = CacheManager::createDescriptorKey(c, a, b);
19
20
21

    infiniopAddDescriptor_t desc;
    if (!cache_manager->getAddDescriptor(key, desc)) {
blkmjsian's avatar
blkmjsian committed
22
        RUN_INFINI(infiniopCreateAddDescriptor(op_handle, &desc, c->desc(), a->desc(), b->desc()));
23
24
25
26
27
28
29
30
31
32
33
34
35
        cache_manager->putAddDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetAddWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopAdd(
        desc, workspace, workspace_size,
        c->data(), a->data(), b->data(), stream));
}

hejianlin's avatar
hejianlin committed
36
void InferenceContext::conv(std::shared_ptr<Tensor> y,
PanZezhong's avatar
PanZezhong committed
37
38
39
40
41
42
43
                            std::shared_ptr<Tensor> x,
                            std::shared_ptr<Tensor> w,
                            std::shared_ptr<Tensor> bias,
                            void *pads,
                            void *strides,
                            void *dilations,
                            size_t n) {
hejianlin's avatar
hejianlin committed
44
45
    size_t key = CacheManager::createDescriptorKey(y, x, w, bias);
    // Combine additional parameters into the key for unique identification
PanZezhong's avatar
PanZezhong committed
46
47
48
    hash_combine(key, std::hash<void *>()(pads));
    hash_combine(key, std::hash<void *>()(strides));
    hash_combine(key, std::hash<void *>()(dilations));
hejianlin's avatar
hejianlin committed
49
50
51
52
53
    hash_combine(key, std::hash<size_t>()(n));

    infiniopConvDescriptor_t desc;
    if (!cache_manager->getConvDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateConvDescriptor(
PanZezhong's avatar
PanZezhong committed
54
            op_handle, &desc, y->desc(), x->desc(), w->desc(),
hejianlin's avatar
hejianlin committed
55
56
57
58
59
60
61
62
63
64
65
            bias ? bias->desc() : nullptr, pads, strides, dilations, n));
        cache_manager->putConvDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetConvWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopConv(
        desc, workspace, workspace_size,
PanZezhong's avatar
PanZezhong committed
66
        y->data(), x->data(), w->data(),
hejianlin's avatar
hejianlin committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
        bias ? bias->data() : nullptr, stream));
}

void InferenceContext::mul(std::shared_ptr<Tensor> c,
                           std::shared_ptr<Tensor> a,
                           std::shared_ptr<Tensor> b) {
    size_t key = CacheManager::createDescriptorKey(c, a, b);

    infiniopMulDescriptor_t desc;
    if (!cache_manager->getMulDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateMulDescriptor(op_handle, &desc, c->desc(), a->desc(), b->desc()));
        cache_manager->putMulDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetMulWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopMul(
        desc, workspace, workspace_size,
        c->data(), a->data(), b->data(), stream));
}

wooway777's avatar
wooway777 committed
91
92
93
94
void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
                               std::shared_ptr<Tensor> x,
                               std::shared_ptr<Tensor> w,
                               float epsilon) {
95
    size_t key = CacheManager::createDescriptorKey(y, x, w);
wooway777's avatar
wooway777 committed
96
97
98
99

    infiniopRMSNormDescriptor_t desc;
    if (!cache_manager->getRMSNormDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRMSNormDescriptor(
blkmjsian's avatar
blkmjsian committed
100
            op_handle, &desc, y->desc(), x->desc(), w->desc(), epsilon));
wooway777's avatar
wooway777 committed
101
102
103
104
105
106
107
108
109
110
111
112
113
        cache_manager->putRMSNormDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRMSNorm(
        desc, workspace, workspace_size,
        y->data(), x->data(), w->data(), stream));
}

114
115
116
void InferenceContext::gemm(std::shared_ptr<Tensor> c,
                            std::shared_ptr<Tensor> a,
                            std::shared_ptr<Tensor> b,
wooway777's avatar
wooway777 committed
117
                            float alpha, float beta) {
118
    size_t key = CacheManager::createDescriptorKey(c, a, b);
wooway777's avatar
wooway777 committed
119
120
121

    infiniopGemmDescriptor_t desc;
    if (!cache_manager->getGemmDescriptor(key, desc)) {
blkmjsian's avatar
blkmjsian committed
122
        RUN_INFINI(infiniopCreateGemmDescriptor(op_handle, &desc, c->desc(), a->desc(), b->desc()));
wooway777's avatar
wooway777 committed
123
124
125
126
127
128
129
130
131
132
133
134
135
        cache_manager->putGemmDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopGemm(
        desc, workspace, workspace_size,
        c->data(), a->data(), b->data(), alpha, beta, stream));
}

136
137
void InferenceContext::rearrange(std::shared_ptr<Tensor> dst,
                                 std::shared_ptr<Tensor> src) {
138
    size_t key = CacheManager::createDescriptorKey(dst, src);
wooway777's avatar
wooway777 committed
139
140
141

    infiniopRearrangeDescriptor_t desc;
    if (!cache_manager->getRearrangeDescriptor(key, desc)) {
blkmjsian's avatar
blkmjsian committed
142
        RUN_INFINI(infiniopCreateRearrangeDescriptor(op_handle, &desc, dst->desc(), src->desc()));
wooway777's avatar
wooway777 committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
        cache_manager->putRearrangeDescriptor(key, desc);
    }

    RUN_INFINI(infiniopRearrange(
        desc,
        dst->data(),
        src->data(),
        stream));
}

void InferenceContext::rope(std::shared_ptr<Tensor> q,
                            std::shared_ptr<Tensor> k,
                            std::shared_ptr<Tensor> pos,
                            std::shared_ptr<Tensor> sin,
PanZezhong1725's avatar
PanZezhong1725 committed
157
158
                            std::shared_ptr<Tensor> cos,
                            infiniopRoPEAlgo_t algo) {
159
    size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
PanZezhong1725's avatar
PanZezhong1725 committed
160
    hash_combine(key, std::hash<int>()(algo));
wooway777's avatar
wooway777 committed
161
162
163
164

    infiniopRoPEDescriptor_t desc;
    if (!cache_manager->getRoPEDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRoPEDescriptor(
blkmjsian's avatar
blkmjsian committed
165
            op_handle, &desc, q->desc(), k->desc(),
PanZezhong1725's avatar
PanZezhong1725 committed
166
            pos->desc(), sin->desc(), cos->desc(), algo));
wooway777's avatar
wooway777 committed
167
168
169
170
171
172
173
174
175
176
177
178
179
180
        cache_manager->putRoPEDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRoPE(
        desc, workspace, workspace_size,
        q->data(), k->data(), pos->data(),
        sin->data(), cos->data(), stream));
}

181
182
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
                                     std::shared_ptr<Tensor> x) {
183
    size_t key = CacheManager::createDescriptorKey(y, x);
wooway777's avatar
wooway777 committed
184
185
186
187

    infiniopCausalSoftmaxDescriptor_t desc;
    if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
blkmjsian's avatar
blkmjsian committed
188
            op_handle, &desc, y->desc(), x->desc()));
wooway777's avatar
wooway777 committed
189
190
191
192
193
194
195
196
197
198
199
200
        cache_manager->putCausalSoftmaxDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopCausalSoftmax(desc, workspace, workspace_size,
                                     y->data(), x->data(), stream));
}

blkmjsian's avatar
blkmjsian committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
void InferenceContext::topkrouter(std::shared_ptr<Tensor> values,  // F32
                                  std::shared_ptr<Tensor> indices, // I32
                                  std::shared_ptr<Tensor> x,
                                  std::shared_ptr<Tensor> correction_bias, // F32
                                  float routed_scaling_factor,
                                  size_t topk) {
    size_t key = CacheManager::createDescriptorKey(values, indices, x, correction_bias);

    infiniopTopkrouterDescriptor_t desc;
    if (!cache_manager->getTopkrouterDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateTopkrouterDescriptor(
            op_handle, &desc, x->desc(), correction_bias->desc()));
        cache_manager->putTopkrouterDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetTopkrouterWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopTopkrouter(desc, workspace, workspace_size,
                                  values->data(), indices->data(), x->data(), correction_bias->data(),
                                  routed_scaling_factor, topk, stream));
}

226
227
228
void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
                              std::shared_ptr<Tensor> up,
                              std::shared_ptr<Tensor> gate) {
229
    size_t key = CacheManager::createDescriptorKey(out, up, gate);
wooway777's avatar
wooway777 committed
230
231
232
233

    infiniopSwiGLUDescriptor_t desc;
    if (!cache_manager->getSwiGLUDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateSwiGLUDescriptor(
blkmjsian's avatar
blkmjsian committed
234
            op_handle, &desc, out->desc(), up->desc(), gate->desc()));
wooway777's avatar
wooway777 committed
235
236
237
238
239
240
241
242
243
244
245
246
        cache_manager->putSwiGLUDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopSwiGLU(desc, workspace, workspace_size,
                              out->data(), up->data(), gate->data(), stream));
}

hejianlin's avatar
hejianlin committed
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
void InferenceContext::silu(std::shared_ptr<Tensor> out,
                            std::shared_ptr<Tensor> input) {
    size_t key = CacheManager::createDescriptorKey(out, input);

    infiniopSiluDescriptor_t desc;
    if (!cache_manager->getSiluDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateSiluDescriptor(
            op_handle, &desc, out->desc(), input->desc()));
        cache_manager->putSiluDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetSiluWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopSilu(desc, workspace, workspace_size,
                            out->data(), input->data(), stream));
}

267
268
void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
                                    std::shared_ptr<Tensor> prob,
wooway777's avatar
wooway777 committed
269
                                    float random_val, float top_p, uint32_t top_k, float temperature) {
270
    size_t key = CacheManager::createDescriptorKey(out, prob);
wooway777's avatar
wooway777 committed
271
272
273
274

    infiniopRandomSampleDescriptor_t desc;
    if (!cache_manager->getRandomSampleDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateRandomSampleDescriptor(
blkmjsian's avatar
blkmjsian committed
275
            op_handle, &desc, out->desc(), prob->desc()));
wooway777's avatar
wooway777 committed
276
277
278
279
280
281
282
283
284
285
286
287
288
289
        cache_manager->putRandomSampleDescriptor(key, desc);
    }

    size_t workspace_size = 0;
    RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

    RUN_INFINI(infiniopRandomSample(
        desc, workspace, workspace_size,
        out->data(), prob->data(),
        random_val, top_p, top_k, temperature,
        stream));
}
290
291
292
293
294

void InferenceContext::linear(std::shared_ptr<Tensor> c,
                              std::shared_ptr<Tensor> a,
                              std::shared_ptr<Tensor> b,
                              float alpha, float beta,
295
296
                              std::shared_ptr<Tensor> residual,
                              std::shared_ptr<Tensor> bias) {
297
298
299
300
301
302
303
304
305
306
307
308
    bool residual_flag = residual != nullptr;

    if (bias && !residual) {
        int ndim_diff = c->ndim() - 1;
        ASSERT_EQ(bias->ndim(), 1);
        ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
        std::vector<ptrdiff_t> strides(ndim_diff, 0);
        strides.push_back(bias->strides()[0]);
        rearrange(c, bias->view_as(c->shape(), strides));
        residual = c;
    }

309
310
311
312
313
    if (residual) {
        if (residual->data() == c->data()) {
            if (beta == 0.0) {
                gemm(c, a, b, alpha, 1.0);
            } else {
blkmjsian's avatar
blkmjsian committed
314
315
                auto c_copy = Tensor::buffer(c->dtype(), c->shape(), memory_pool);
                c_copy->copyFrom(c, op_handle, stream);
316
317
318
319
320
321
322
323
324
325
                gemm(c, a, b, alpha, beta);
                add(c, c, c_copy);
            }
        } else {
            gemm(c, a, b, alpha, beta);
            add(c, c, residual);
        }
    } else {
        gemm(c, a, b, alpha, beta);
    }
326

327
    if (bias && residual_flag) {
328
329
330
331
332
333
334
        int ndim_diff = c->ndim() - 1;
        ASSERT_EQ(bias->ndim(), 1);
        ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
        std::vector<ptrdiff_t> strides(ndim_diff, 0);
        strides.push_back(bias->strides()[0]);
        add(c, c, bias->view_as(c->shape(), strides));
    }
335
}
blkmjsian's avatar
blkmjsian committed
336
337
338
339
340
341
342
343

void InferenceContext::dequant(std::shared_ptr<Tensor> weight,
                               std::shared_ptr<Tensor> in_w,
                               std::shared_ptr<Tensor> in_s,
                               std::shared_ptr<Tensor> in_z) {

    size_t key = CacheManager::createDescriptorKey(weight, in_w, in_s, in_z);

344
345
346
347
    infiniopDequantizeAWQDescriptor_t desc;
    if (!cache_manager->getDequantizeAWQDescriptor(key, desc)) {
        RUN_INFINI(infiniopCreateDequantizeAWQDescriptor(op_handle, &desc, weight->desc(), in_w->desc(), in_s->desc(), in_z->desc()));
        cache_manager->putDequantizeAWQDescriptor(key, desc);
blkmjsian's avatar
blkmjsian committed
348
349
350
    }

    size_t workspace_size = 0;
351
    RUN_INFINI(infiniopGetDequantizeAWQWorkspaceSize(desc, &workspace_size));
blkmjsian's avatar
blkmjsian committed
352
353
354
    ensure_workspace(workspace_size);
    void *workspace = workspace_storage->memory();

355
    RUN_INFINI(infiniopDequantizeAWQ(
blkmjsian's avatar
blkmjsian committed
356
        desc, workspace, workspace_size,
357
        weight->data(), in_w->data(), in_s->data(), in_z->data(), stream));
blkmjsian's avatar
blkmjsian committed
358
}