#ifndef INFER_TENSOR_H #define INFER_TENSOR_H #include "allocator.hpp" #include "infinicore_infer.h" #include "utils.hpp" #include #include #include class Storage { private: Storage() = default; void *_memory; size_t _size; infiniDevice_t _device_type; int _device_id; std::shared_ptr _memory_pool; public: static std::shared_ptr create(size_t size); static std::shared_ptr createAsync(size_t size, infinirtStream_t stream = nullptr); static std::shared_ptr createFromPool(size_t size, std::shared_ptr pool = nullptr); static std::shared_ptr createHost(size_t size); ~Storage(); void *memory() const { return _memory; } size_t size() const { return _size; } infiniDevice_t deviceType() const { return _device_type; } int deviceId() const { return _device_id; } }; struct SliceParams { size_t dim; size_t start; size_t len; }; template std::vector __shape(Args... args) { return std::vector{static_cast(args)...}; } template std::vector __strides(Args... args) { return std::vector{static_cast(args)...}; } class TensorDesc { private: infiniDtype_t _dtype; std::vector _shape; std::vector _strides; infiniopTensorDescriptor_t _desc; TensorDesc(infiniDtype_t dtype, const std::vector &shape, const std::vector &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) {} void resetDesc(); public: ~TensorDesc(); static std::shared_ptr create(infiniDtype_t dtype, const std::vector &shape, const std::vector &strides); static std::shared_ptr create(infiniDtype_t dtype, const std::vector &shape); static std::shared_ptr createWithOrder(infiniDtype_t dtype, const std::vector &shape, const std::vector &order); infiniDtype_t dtype() const { return _dtype; } const std::vector &shape() const { return _shape; } const std::vector &strides() const { return _strides; } size_t ndim() const { return _shape.size(); } infiniopTensorDescriptor_t desc() const; bool isContigous() const; std::string info() const; void dimMerge(size_t dim_start, size_t dim_end); void dimSplit(size_t dim, const std::vector &dims); void permute(const std::vector &order); void reDesc(const std::vector new_shape, const std::vector new_strides); }; class Tensor : public std::enable_shared_from_this { private: std::shared_ptr _storage; std::shared_ptr _desc; ptrdiff_t _offset; void *dataImpl(ptrdiff_t offset) const; std::shared_ptr sliceImpl(const std::vector &slices) const; public: static std::shared_ptr buffer(infiniDtype_t dtype, const std::vector &shape, std::shared_ptr pool = nullptr); static std::shared_ptr weight(void *host_data, infiniDtype_t dtype, const std::vector &shape); std::shared_ptr memShare(const std::vector &shape, infiniDtype_t dtype = INFINI_DTYPE_INVALID) const; std::shared_ptr slice(size_t dim, size_t start, size_t len); std::shared_ptr slice(size_t dim, size_t start, size_t len) const; std::shared_ptr slice(const std::vector &slices); std::shared_ptr slice(const std::vector &slices) const; std::shared_ptr dimMerge(size_t dim_start, size_t dim_end); std::shared_ptr dimSplit(size_t dim, const std::vector &dims); std::shared_ptr permute(const std::vector &order); std::shared_ptr reDesc(const std::vector new_shape, const std::vector new_strides); void *data(ptrdiff_t offset = 0); void const *data(ptrdiff_t offset = 0) const; void copyFrom(std::shared_ptr src, infiniopHandle_t handle, infinirtStream_t stream = nullptr); const std::vector &shape() const; const std::vector &strides() const; size_t ndim() const; infiniDtype_t dtype() const; bool isContigous() const; infiniopTensorDescriptor_t desc() const; ptrdiff_t dataOffset() const; infiniDevice_t deviceType() const; int deviceId() const; void debug(const std::string &filename) const; void debug() const; std::string info() const; ~Tensor(); }; inline size_t dsize(infiniDtype_t dtype) { switch (dtype) { case INFINI_DTYPE_INVALID: return 0; case INFINI_DTYPE_BYTE: return 1; case INFINI_DTYPE_BOOL: return 1; case INFINI_DTYPE_I8: return 1; case INFINI_DTYPE_I16: return 2; case INFINI_DTYPE_I32: return 4; case INFINI_DTYPE_I64: return 8; case INFINI_DTYPE_U8: return 1; case INFINI_DTYPE_U16: return 2; case INFINI_DTYPE_U32: return 4; case INFINI_DTYPE_U64: return 8; case INFINI_DTYPE_F8: return 1; case INFINI_DTYPE_F16: return 2; case INFINI_DTYPE_F32: return 4; case INFINI_DTYPE_F64: return 8; case INFINI_DTYPE_C16: return 2; case INFINI_DTYPE_C32: return 4; case INFINI_DTYPE_C64: return 8; case INFINI_DTYPE_C128: return 16; case INFINI_DTYPE_BF16: return 2; default: return 0; } } #endif