tensor.h 1.72 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
2
3
4
#ifndef __INFINIOP_TENSOR_H__
#define __INFINIOP_TENSOR_H__

#include "infiniop/tensor_descriptor.h"
PanZezhong's avatar
PanZezhong committed
5
6
7

#include "../utils.h"

PanZezhong's avatar
PanZezhong committed
8
9
10
#include <string>
#include <vector>

PanZezhong's avatar
PanZezhong committed
11
12
13
14
15
16
17
// Replace a tensor descriptor with the result of one of its
// utils::Result-returning member transforms (e.g. dimMerge / dimSplit /
// dimPermute).
//
//   DESC: an assignable descriptor-pointer lvalue (may be an expression such
//         as `*pp`; it is parenthesized before `->` is applied).
//   OP:   the member call to apply, written without the `->`,
//         e.g. `dimMerge(0, 2)`.
//
// The result is passed to CHECK_RESULT before DESC is overwritten.
// NOTE(review): assumes CHECK_RESULT (utils.h) leaves the enclosing scope on
// error, so DESC keeps its previous value on failure — confirm.
//
// The temporary uses a macro-specific, non-reserved name (identifiers
// containing `__` are reserved for the implementation) so it cannot collide
// with a caller variable referenced inside OP.
#define TRANSFORM_TENSOR_DESC(DESC, OP)                 \
    do {                                                \
        auto infiniop_transform_result_ = (DESC)->OP;   \
        CHECK_RESULT(infiniop_transform_result_);       \
        (DESC) = infiniop_transform_result_.take();     \
    } while (0)

PanZezhong's avatar
PanZezhong committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Layout metadata for a strided tensor: element datatype, per-dimension shape,
// and per-dimension strides (in elements). It stores no data pointer — only
// the layout description.
//
// Every member is a value type, so the implicitly-declared destructor and
// copy/move operations are correct (Rule of Zero). Do not add a destructor
// declaration: a user-declared one, even `= default`, suppresses the implicit
// move constructor and move assignment, turning moves into vector copies.
struct InfiniopTensorDescriptor {
private:
    // Datatype
    infiniDtype_t _dtype;
    // Shape of the tensor
    std::vector<size_t> _shape;
    // Stride of each dimension in elements
    std::vector<ptrdiff_t> _strides;

public:
    // NOTE(review): presumably copies `ndim` entries from `shape` and
    // `strides` into the owned vectors — confirm null handling in the
    // definition.
    InfiniopTensorDescriptor(infiniDtype_t dtype, size_t ndim, const size_t *shape, const ptrdiff_t *strides);

    infiniDtype_t dtype() const;
    std::vector<size_t> shape() const;
    size_t dim(size_t i) const;
    size_t ndim() const;
    std::vector<ptrdiff_t> strides() const;
    ptrdiff_t stride(size_t i) const;
    // Presumably the element strides scaled by the dtype size — confirm in
    // the definition.
    std::vector<ptrdiff_t> getByteStrides() const;
    bool isContiguous(size_t dim_start, size_t dim_end) const;
    bool isContiguous() const;
    size_t numel() const;

    // A dim is broadcasted if its corresponding stride is 0 but its size is > 1.
    bool hasBroadcastDim() const;
    std::vector<size_t> getBroadcastDim() const;

    // Layout transforms. They are const and hand back a new descriptor wrapped
    // in utils::Result rather than modifying this one.
    utils::Result<infiniopTensorDescriptor_t> dimMerge(size_t dim_start, size_t dim_end) const;
    utils::Result<infiniopTensorDescriptor_t> dimSplit(size_t axis, const std::vector<size_t> &dims) const;
    utils::Result<infiniopTensorDescriptor_t> dimPermute(const std::vector<size_t> &order) const;

    std::string toString() const;
};

#endif // __INFINIOP_TENSOR_H__