tensor.hpp 6.2 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
#ifndef INFER_TENSOR_H
#define INFER_TENSOR_H

thatPepe's avatar
thatPepe committed
4
#include "allocator.hpp"
PanZezhong's avatar
init  
PanZezhong committed
5
6
7
8
9
#include "utils.hpp"
#include <memory>
#include <string>
#include <vector>

thatPepe's avatar
thatPepe committed
10
class Storage {
wooway777's avatar
wooway777 committed
11
12
private:
    Storage() = default;
PanZezhong's avatar
PanZezhong committed
13
14
15
16
17
    void *_memory;
    size_t _size;
    infiniDevice_t _device_type;
    int _device_id;
    std::shared_ptr<MemoryPool> _memory_pool;
wooway777's avatar
wooway777 committed
18

thatPepe's avatar
thatPepe committed
19
public:
PanZezhong's avatar
init  
PanZezhong committed
20
21
    static std::shared_ptr<Storage> create(size_t size);
    static std::shared_ptr<Storage> createAsync(size_t size, infinirtStream_t stream = nullptr);
thatPepe's avatar
thatPepe committed
22
    static std::shared_ptr<Storage> createFromPool(size_t size, std::shared_ptr<MemoryPool> pool = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
23
24
    static std::shared_ptr<Storage> createHost(size_t size);
    ~Storage();
PanZezhong's avatar
PanZezhong committed
25
26
27
28
29

    void *memory() const { return _memory; }
    size_t size() const { return _size; }
    infiniDevice_t deviceType() const { return _device_type; }
    int deviceId() const { return _device_id; }
PanZezhong's avatar
init  
PanZezhong committed
30
31
32
33
34
35
36
37
};

/// One slicing request: along dimension `dim`, keep `len` entries starting at
/// index `start`. Fields default to 0 so a default-initialized SliceParams is
/// never indeterminate; aggregate initialization `{dim, start, len}` still
/// works (C++14 and later).
struct SliceParams {
    size_t dim = 0;   // dimension being sliced
    size_t start = 0; // first index kept along `dim`
    size_t len = 0;   // number of indices kept along `dim`
};

38
39
40
41
42
43
44
45
46
/// Builds a shape vector from any number of arguments, converting each one to
/// size_t with static_cast. Calling it with no arguments yields an empty shape.
template <typename... Dims>
std::vector<size_t> __shape(Dims... dims) {
    std::vector<size_t> result{static_cast<size_t>(dims)...};
    return result;
}

/// Builds a strides vector from any number of arguments, converting each one
/// to ptrdiff_t with static_cast (so negative strides are representable).
/// Calling it with no arguments yields an empty vector.
template <typename... Steps>
std::vector<ptrdiff_t> __strides(Steps... steps) {
    std::vector<ptrdiff_t> result{static_cast<ptrdiff_t>(steps)...};
    return result;
}
PanZezhong's avatar
init  
PanZezhong committed
47
48
class TensorDesc {
private:
PanZezhong's avatar
PanZezhong committed
49
50
51
    infiniDtype_t _dtype;
    std::vector<size_t> _shape;
    std::vector<ptrdiff_t> _strides;
PanZezhong's avatar
init  
PanZezhong committed
52
    infiniopTensorDescriptor_t _desc;
53
    size_t _seed;
PanZezhong's avatar
init  
PanZezhong committed
54

PanZezhong's avatar
PanZezhong committed
55
    TensorDesc(infiniDtype_t dtype, const std::vector<size_t> &shape,
56
               const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) { computeTensorDesHash(); }
PanZezhong's avatar
PanZezhong committed
57
    void resetDesc();
58
    void computeTensorDesHash();
PanZezhong's avatar
PanZezhong committed
59

PanZezhong's avatar
init  
PanZezhong committed
60
public:
PanZezhong's avatar
PanZezhong committed
61
    ~TensorDesc();
PanZezhong's avatar
init  
PanZezhong committed
62
63
64
    static std::shared_ptr<TensorDesc>
    create(infiniDtype_t dtype, const std::vector<size_t> &shape,
           const std::vector<ptrdiff_t> &strides);
65
66
67
68
69
    static std::shared_ptr<TensorDesc>
    create(infiniDtype_t dtype, const std::vector<size_t> &shape);
    static std::shared_ptr<TensorDesc>
    createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shape,
                    const std::vector<size_t> &order);
PanZezhong's avatar
PanZezhong committed
70
71
72
73
74
75
76
77

    infiniDtype_t dtype() const { return _dtype; }
    const std::vector<size_t> &shape() const { return _shape; }
    const std::vector<ptrdiff_t> &strides() const { return _strides; }
    size_t ndim() const { return _shape.size(); }
    infiniopTensorDescriptor_t desc() const;
    bool isContigous() const;
    std::string info() const;
78
    size_t seed() const { return _seed; }
PanZezhong's avatar
PanZezhong committed
79
80
81
82

    void dimMerge(size_t dim_start, size_t dim_end);
    void dimSplit(size_t dim, const std::vector<size_t> &dims);
    void permute(const std::vector<size_t> &order);
PanZezhong's avatar
init  
PanZezhong committed
83
84
85
86
};

class Tensor : public std::enable_shared_from_this<Tensor> {
private:
PanZezhong's avatar
PanZezhong committed
87
    std::shared_ptr<Storage> _storage;
wooway777's avatar
wooway777 committed
88
    std::shared_ptr<const TensorDesc> _desc;
PanZezhong's avatar
PanZezhong committed
89
90

    ptrdiff_t _offset;
PanZezhong's avatar
init  
PanZezhong committed
91

PanZezhong's avatar
PanZezhong committed
92
    void *dataImpl(ptrdiff_t offset) const;
PanZezhong's avatar
init  
PanZezhong committed
93
    std::shared_ptr<Tensor>
PanZezhong's avatar
PanZezhong committed
94
    sliceImpl(const std::vector<SliceParams> &slices) const;
PanZezhong's avatar
init  
PanZezhong committed
95
96
97
98

public:
    static std::shared_ptr<Tensor> buffer(infiniDtype_t dtype,
                                          const std::vector<size_t> &shape,
thatPepe's avatar
thatPepe committed
99
                                          std::shared_ptr<MemoryPool> pool = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
100
101
102
    static std::shared_ptr<Tensor> weight(void *host_data,
                                          infiniDtype_t dtype,
                                          const std::vector<size_t> &shape);
blkmjsian's avatar
blkmjsian committed
103
    void load(const void *host_data, infinirtStream_t stream = nullptr);
104
105
    std::shared_ptr<Tensor> memShare(const std::vector<size_t> &shape,
                                     infiniDtype_t dtype = INFINI_DTYPE_INVALID) const;
PanZezhong's avatar
init  
PanZezhong committed
106
107
108
109
110
111
    std::shared_ptr<Tensor> slice(size_t dim, size_t start, size_t len);
    std::shared_ptr<Tensor const> slice(size_t dim, size_t start,
                                        size_t len) const;
    std::shared_ptr<Tensor> slice(const std::vector<SliceParams> &slices);
    std::shared_ptr<Tensor const>
    slice(const std::vector<SliceParams> &slices) const;
PanZezhong's avatar
PanZezhong committed
112
113
114
    std::shared_ptr<Tensor> dimMerge(size_t dim_start, size_t dim_end);
    std::shared_ptr<Tensor> dimSplit(size_t dim,
                                     const std::vector<size_t> &dims);
PanZezhong's avatar
init  
PanZezhong committed
115
116
117
    std::shared_ptr<Tensor> permute(const std::vector<size_t> &order);
    void *data(ptrdiff_t offset = 0);
    void const *data(ptrdiff_t offset = 0) const;
PanZezhong's avatar
PanZezhong committed
118
119
    void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle,
                  infinirtStream_t stream = nullptr);
PanZezhong's avatar
init  
PanZezhong committed
120
121
122
123
    const std::vector<size_t> &shape() const;
    const std::vector<ptrdiff_t> &strides() const;
    size_t ndim() const;
    infiniDtype_t dtype() const;
PanZezhong's avatar
PanZezhong committed
124
125
    bool isContigous() const;
    infiniopTensorDescriptor_t desc() const;
PanZezhong's avatar
PanZezhong committed
126
127
128
    ptrdiff_t dataOffset() const;
    infiniDevice_t deviceType() const;
    int deviceId() const;
blkmjsian's avatar
blkmjsian committed
129
    size_t numel() const;
PanZezhong's avatar
init  
PanZezhong committed
130
131
132

    void debug(const std::string &filename) const;
    void debug() const;
PanZezhong's avatar
PanZezhong committed
133
    std::string info() const;
134
    size_t seed() const;
PanZezhong's avatar
init  
PanZezhong committed
135

136
    std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
137
    std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
138
    std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
139

PanZezhong's avatar
init  
PanZezhong committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    ~Tensor();
};

// Size in bytes of a single element of `dtype`.
// The complex entries follow their total bit width (C16 -> 2 bytes,
// C32 -> 4, C64 -> 8, C128 -> 16). INFINI_DTYPE_INVALID and any value not
// listed below yield 0.
inline size_t dsize(infiniDtype_t dtype) {
    size_t bytes = 0;
    switch (dtype) {
    case INFINI_DTYPE_BYTE:
    case INFINI_DTYPE_BOOL:
    case INFINI_DTYPE_I8:
    case INFINI_DTYPE_U8:
    case INFINI_DTYPE_F8:
        bytes = 1;
        break;
    case INFINI_DTYPE_I16:
    case INFINI_DTYPE_U16:
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_BF16:
    case INFINI_DTYPE_C16:
        bytes = 2;
        break;
    case INFINI_DTYPE_I32:
    case INFINI_DTYPE_U32:
    case INFINI_DTYPE_F32:
    case INFINI_DTYPE_C32:
        bytes = 4;
        break;
    case INFINI_DTYPE_I64:
    case INFINI_DTYPE_U64:
    case INFINI_DTYPE_F64:
    case INFINI_DTYPE_C64:
        bytes = 8;
        break;
    case INFINI_DTYPE_C128:
        bytes = 16;
        break;
    case INFINI_DTYPE_INVALID:
    default:
        bytes = 0;
        break;
    }
    return bytes;
}

#endif