#ifndef RTG_GUARD_RTGLIB_SHAPE_HPP #define RTG_GUARD_RTGLIB_SHAPE_HPP #include #include #include #include namespace rtg { struct shape { // Add new types here // clang-format off #define RTG_SHAPE_VISIT_TYPES(m) \ m(float_type, float) \ m(double_type, double) \ m(uint8_type, uint8_t) \ m(int8_type, int8_t) \ m(uint16_type, uint16_t) \ m(int16_type, int16_t) \ m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ m(uint64_type, uint64_t) // clang-format on #define RTG_SHAPE_ENUM_TYPES(x, t) x, enum type_t { any_type, RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES) }; #undef RTG_SHAPE_ENUM_TYPES template struct get_type : std::integral_constant { }; #define RTG_SHAPE_GET_TYPE(x, t) \ template \ struct get_type : std::integral_constant \ { \ }; RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE) #undef RTG_SHAPE_GET_TYPE shape(); shape(type_t t); shape(type_t t, std::vector l); shape(type_t t, std::vector l, std::vector s); type_t type() const; const std::vector& lens() const; const std::vector& strides() const; std::size_t elements() const; std::size_t bytes() const; std::size_t index(std::initializer_list l) const; std::size_t index(const std::vector& l) const; // Map element index to space index std::size_t index(std::size_t i) const; bool packed() const; friend bool operator==(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y); friend std::ostream& operator<<(std::ostream& os, const shape& x); template struct as { using type = T; template T operator()(U u) const { return T(u); } template T* operator()(U* u) const { return static_cast(u); } template const T* operator()(const U* u) const { return static_cast(u); } T operator()() const { return {}; } std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; } template T* from(U* buffer, std::size_t n = 0) const { return reinterpret_cast(buffer) + n; } template const T* from(const U* buffer, std::size_t n = 0) const { return reinterpret_cast(buffer) + n; } }; template void visit_type(Visitor v) const { switch(this->m_type) { case any_type: RTG_THROW("Cannot visit the any_type"); #define RTG_SHAPE_VISITOR_CASE(x, t) \ case x: v(as()); return; RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE) #undef RTG_SHAPE_VISITOR_CASE } RTG_THROW("Unknown type"); } private: type_t m_type; std::vector m_lens; std::vector m_strides; bool m_packed; void calculate_strides(); std::size_t element_space() const; std::string type_string() const; }; } // namespace rtg #endif