// tensor.hpp: Storage, TensorDesc and Tensor declarations.
#ifndef INFER_TENSOR_H
#define INFER_TENSOR_H

#include "allocator.hpp"
#include "infinicore_infer.h"
#include "utils.hpp"
#include <memory>
#include <string>
#include <vector>

class Storage {
wooway777's avatar
wooway777 committed
12
13
14
private:
    Storage() = default;

thatPepe's avatar
thatPepe committed
15
public:
PanZezhong's avatar
init  
PanZezhong committed
16
17
18
19
    void *memory;
    size_t size;
    infiniDevice_t device_type;
    int device_id;
thatPepe's avatar
thatPepe committed
20
    std::shared_ptr<MemoryPool> memory_pool;
PanZezhong's avatar
init  
PanZezhong committed
21
22
23

    static std::shared_ptr<Storage> create(size_t size);
    static std::shared_ptr<Storage> createAsync(size_t size, infinirtStream_t stream = nullptr);
thatPepe's avatar
thatPepe committed
24
    static std::shared_ptr<Storage> createFromPool(size_t size, std::shared_ptr<MemoryPool> pool = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
25
26
27
28
29
30
31
32
33
34
    static std::shared_ptr<Storage> createHost(size_t size);
    ~Storage();
};

/// One slicing request passed to `Tensor::slice`.
///
/// The field names suggest: along dimension `dim`, keep `len` indices starting
/// at `start` (confirm against Tensor::sliceImpl). All fields default to 0, so
/// a default-constructed SliceParams is never left with indeterminate values;
/// aggregate initialization `{dim, start, len}` still works.
struct SliceParams {
    size_t dim = 0;   // dimension being sliced
    size_t start = 0; // first index kept along `dim`
    size_t len = 0;   // number of indices kept
};

/// Collects the arguments into a shape vector, converting each one with
/// `static_cast<size_t>`; e.g. `__shape(2, 3)` yields `{2, 3}` and `__shape()`
/// yields an empty vector.
/// Note: names starting with `__` are reserved to the implementation; the
/// name is kept because existing callers use it.
template <typename... Args>
std::vector<size_t> __shape(Args... args) {
    std::vector<size_t> dims{static_cast<size_t>(args)...};
    return dims;
}

/// Collects the arguments into a strides vector, converting each one with
/// `static_cast<ptrdiff_t>` (so negative strides are representable); e.g.
/// `__strides(12, 4, 1)` yields `{12, 4, 1}`.
/// Note: names starting with `__` are reserved to the implementation; the
/// name is kept because existing callers use it.
template <typename... Args>
std::vector<ptrdiff_t> __strides(Args... args) {
    std::vector<ptrdiff_t> steps{static_cast<ptrdiff_t>(args)...};
    return steps;
}
/// Owns one `infiniopTensorDescriptor_t` handle built from a dtype, a shape
/// and optionally strides or a dimension order.
///
/// Objects are handed out by the static `create*` factories as
/// `std::shared_ptr<TensorDesc>`. The handle is presumably released by the
/// out-of-line `~TensorDesc()` (confirm in tensor.cpp).
class TensorDesc {
private:
    // Value-initialized so a TensorDesc that was never filled in holds a
    // zero/null handle instead of an indeterminate one.
    infiniopTensorDescriptor_t _desc{};

public:
    // Explicitly defaulted: declaring the deleted copy constructor below
    // would otherwise suppress the implicit default constructor.
    TensorDesc() = default;
    // Copying is disabled: with a raw handle and a user-provided destructor,
    // a member-wise copy would give two objects whose destructors both act on
    // the same descriptor.
    TensorDesc(const TensorDesc &) = delete;
    TensorDesc &operator=(const TensorDesc &) = delete;

    /// Descriptor with explicit per-dimension `strides`.
    static std::shared_ptr<TensorDesc>
    create(infiniDtype_t dtype, const std::vector<size_t> &shape,
           const std::vector<ptrdiff_t> &strides);
    /// Descriptor whose strides are derived from `shape` (presumably
    /// contiguous — confirm in tensor.cpp).
    static std::shared_ptr<TensorDesc>
    create(infiniDtype_t dtype, const std::vector<size_t> &shape);
    /// Descriptor whose layout follows the dimension permutation `order`
    /// (exact semantics defined in tensor.cpp).
    static std::shared_ptr<TensorDesc>
    createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shape,
                    const std::vector<size_t> &order);

    /// Returns the raw handle; ownership stays with this object.
    infiniopTensorDescriptor_t get() const { return _desc; }
    ~TensorDesc();
};

class Tensor : public std::enable_shared_from_this<Tensor> {
private:
    infiniDtype_t _dtype;
    std::vector<size_t> _shape;
    std::vector<ptrdiff_t> _strides;
    void *_data;
    ptrdiff_t _offset;
PanZezhong's avatar
PanZezhong committed
68
    std::shared_ptr<Storage> _storage;
PanZezhong's avatar
init  
PanZezhong committed
69
70
    infiniopTensorDescriptor_t _desc;

PanZezhong's avatar
PanZezhong committed
71
    void *dataImpl(ptrdiff_t offset) const;
PanZezhong's avatar
init  
PanZezhong committed
72
    std::shared_ptr<Tensor>
PanZezhong's avatar
PanZezhong committed
73
    sliceImpl(const std::vector<SliceParams> &slices) const;
PanZezhong's avatar
init  
PanZezhong committed
74
75
76
77

public:
    static std::shared_ptr<Tensor> buffer(infiniDtype_t dtype,
                                          const std::vector<size_t> &shape,
thatPepe's avatar
thatPepe committed
78
                                          std::shared_ptr<MemoryPool> pool = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
79
80
81
    static std::shared_ptr<Tensor> weight(void *host_data,
                                          infiniDtype_t dtype,
                                          const std::vector<size_t> &shape);
82
83
    std::shared_ptr<Tensor> memShare(const std::vector<size_t> &shape,
                                     infiniDtype_t dtype = INFINI_DTYPE_INVALID) const;
PanZezhong's avatar
init  
PanZezhong committed
84
85
86
87
88
89
    std::shared_ptr<Tensor> slice(size_t dim, size_t start, size_t len);
    std::shared_ptr<Tensor const> slice(size_t dim, size_t start,
                                        size_t len) const;
    std::shared_ptr<Tensor> slice(const std::vector<SliceParams> &slices);
    std::shared_ptr<Tensor const>
    slice(const std::vector<SliceParams> &slices) const;
PanZezhong's avatar
PanZezhong committed
90
91
92
    std::shared_ptr<Tensor> dimMerge(size_t dim_start, size_t dim_end);
    std::shared_ptr<Tensor> dimSplit(size_t dim,
                                     const std::vector<size_t> &dims);
PanZezhong's avatar
init  
PanZezhong committed
93
94
95
    std::shared_ptr<Tensor> permute(const std::vector<size_t> &order);
    void *data(ptrdiff_t offset = 0);
    void const *data(ptrdiff_t offset = 0) const;
PanZezhong's avatar
PanZezhong committed
96
97
    void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle,
                  infinirtStream_t stream = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
98
99
100
101
102
    const std::vector<size_t> &shape() const;
    const std::vector<ptrdiff_t> &strides() const;
    size_t ndim() const;
    infiniDtype_t dtype() const;
    std::shared_ptr<TensorDesc> desc() const;
PanZezhong's avatar
PanZezhong committed
103
104
105
    ptrdiff_t dataOffset() const;
    infiniDevice_t deviceType() const;
    int deviceId() const;
PanZezhong's avatar
init  
PanZezhong committed
106
107
108
109
    bool is_contigous() const;

    void debug(const std::string &filename) const;
    void debug() const;
PanZezhong's avatar
PanZezhong committed
110
    std::string info() const;
PanZezhong's avatar
init  
PanZezhong committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162

    ~Tensor();
};

/// Returns the number of bytes occupied by one element of `dtype`.
///
/// For every width-suffixed type the result is the suffix divided by 8
/// (e.g. I32 -> 4, C64 -> 8, C128 -> 16); BYTE and BOOL take 1 byte.
/// INFINI_DTYPE_INVALID and any value not listed below yield 0.
inline size_t dsize(infiniDtype_t dtype) {
    switch (dtype) {
    case INFINI_DTYPE_BYTE:
    case INFINI_DTYPE_BOOL:
    case INFINI_DTYPE_I8:
    case INFINI_DTYPE_U8:
    case INFINI_DTYPE_F8:
        return 1;
    case INFINI_DTYPE_I16:
    case INFINI_DTYPE_U16:
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_BF16:
    case INFINI_DTYPE_C16:
        return 2;
    case INFINI_DTYPE_I32:
    case INFINI_DTYPE_U32:
    case INFINI_DTYPE_F32:
    case INFINI_DTYPE_C32:
        return 4;
    case INFINI_DTYPE_I64:
    case INFINI_DTYPE_U64:
    case INFINI_DTYPE_F64:
    case INFINI_DTYPE_C64:
        return 8;
    case INFINI_DTYPE_C128:
        return 16;
    case INFINI_DTYPE_INVALID:
    default:
        return 0;
    }
}

#endif