// inference_context.hpp
#pragma once

#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"

/// Groups the handles needed to issue operator calls (device resource, cache
/// manager, stream, and a workspace buffer) together with the operator entry
/// points themselves.
///
/// This header only declares the members; the constructor and every member
/// function are defined in another translation unit, so the exact semantics of
/// each operator are not visible here.
struct InferenceContext {
    /// Device resource handle. Raw pointer: ownership is not established by
    /// this header — presumably borrowed from the caller; TODO confirm.
    DeviceResource *rsrc;
    /// Cache manager handle. Raw pointer: ownership is not established by this
    /// header — presumably borrowed from the caller; TODO confirm.
    CacheManager *cache_manager;
    /// Stream handle supplied at construction.
    infinirtStream_t stream;
    /// Shared handle to the workspace storage.
    std::shared_ptr<Storage> workspace_storage;
    /// Recorded workspace size; starts at 0.
    /// NOTE(review): units are presumably bytes — confirm in the definition.
    size_t current_workspace_size = 0;

    /// Constructs a context from a device resource, a cache manager and a stream.
    InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream);

    /// Requests a workspace of at least `required_size`.
    /// NOTE(review): presumably (re)allocates `workspace_storage` and updates
    /// `current_workspace_size` when the current one is too small — confirm.
    void ensure_workspace(size_t required_size);

    /// RMS-norm operator entry point (semantics inferred from the name — confirm).
    /// @param y        output tensor (presumed from naming)
    /// @param x        input tensor (presumed from naming)
    /// @param w        weight tensor (presumed from naming)
    /// @param epsilon  scalar passed through to the operator
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);

    /// GEMM operator entry point.
    /// NOTE(review): presumably c = alpha * (a x b) + beta * c — confirm.
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
              float alpha, float beta);

    /// Rearrange operator entry point taking a destination and a source tensor.
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);

    /// RoPE operator entry point taking q, k, position, sin and cos tensors.
    /// NOTE(review): which of `q`/`k` is the output is not visible here — confirm.
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos);

    /// Causal-softmax operator entry point (`y` presumably output, `x` input).
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);

    /// SwiGLU operator entry point taking an output, an `up` and a `gate` tensor.
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);

    /// Random-sampling operator entry point.
    /// @param out          output tensor (presumed from naming)
    /// @param prob         probability/logit tensor (presumed — confirm)
    /// @param random_val   caller-supplied random value
    /// @param top_p        top-p sampling parameter
    /// @param top_k        top-k sampling parameter
    /// @param temperature  sampling temperature
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
                      float random_val, float top_p, uint32_t top_k, float temperature);
};