#ifndef GUARD_RTGLIB_SHAPE_HPP
#define GUARD_RTGLIB_SHAPE_HPP

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <type_traits>
#include <vector>

namespace rtg {

/// Element-type tag plus length/stride vectors describing a buffer layout.
///
/// Only the declarations are visible in this header; the non-template member
/// functions are defined elsewhere. `lens_` / `strides_` are presumably the
/// per-dimension extents and strides -- confirm against the implementation.
struct shape
{

// Add new types here: one `m(enumerator, cpp_type)` entry per line. Every
// entry except the last one needs a line-continuation backslash; the last
// entry must not have one, otherwise the following source line is spliced
// into the macro body.
#define RTG_SHAPE_VISIT_TYPES(m) \
    m(float_type, float)         \
    m(int_type, int)

// Expands the type list into the enumerators of `type_t`.
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
    enum type_t
    {
        RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
    };
#undef RTG_SHAPE_ENUM_TYPES

    /// Maps a C++ element type to its `type_t` enumerator:
    /// `get_type<float>::value == float_type`.
    ///
    /// The second (defaulted) parameter exists so that the per-type mappings
    /// can be written as *partial* specializations inside the class body.
    template <class T, class = void>
    struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t)                              \
    template <class T>                                        \
    struct get_type<t, T> : std::integral_constant<type_t, x> \
    {                                                         \
    };
    RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE

    // Constructors are defined outside this header. `t` is the element type,
    // `l` populates the lengths and `s` the strides (presumably one entry per
    // dimension -- confirm against the implementation).
    shape();
    shape(type_t t);
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);

    // Read-only accessors; definitions live outside this header.
    type_t type() const;
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;
    std::size_t elements() const;
    std::size_t bytes() const;

    // Index computation from a list of per-dimension indices
    // (semantics defined by the out-of-line implementation).
    std::size_t index(std::initializer_list<std::size_t> l) const;
    std::size_t index(const std::vector<std::size_t>& l) const;

    // Map element index to space index
    std::size_t index(std::size_t i) const;

    bool packed() const;

    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);

    /// Stateless tag object handed to visitors by `visit_type`; it carries the
    /// element type `T` and offers helpers to convert values and pointers to it.
    template <class T>
    struct as
    {
        using type = T;

        /// Converts a value to `T` with an explicit conversion.
        template <class U>
        T operator()(U u) const
        {
            return T(u);
        }

        /// Casts a mutable pointer to `T*` via `static_cast` (so `U` must be
        /// `void` or a type related to `T`).
        template <class U>
        T* operator()(U* u) const
        {
            return static_cast<T*>(u);
        }

        /// Const-preserving counterpart of the pointer overload above.
        template <class U>
        const T* operator()(const U* u) const
        {
            return static_cast<const T*>(u);
        }

        /// Returns a value-initialized `T`.
        T operator()() const { return {}; }

        /// Size in bytes of `n` elements of type `T`.
        std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }

        /// Reinterprets `buffer` as an array of `T` and returns a pointer to
        /// element `n`. The caller is responsible for alignment and bounds.
        template <class U>
        T* from(U* buffer, std::size_t n = 0) const
        {
            return reinterpret_cast<T*>(buffer) + n;
        }

        /// Const-preserving counterpart of `from` above.
        template <class U>
        const T* from(const U* buffer, std::size_t n = 0) const
        {
            return reinterpret_cast<const T*>(buffer) + n;
        }
    };

    /// Invokes `v(as<T>())` where `T` is the C++ type matching `type_`.
    ///
    /// If `type_` holds a value that is not one of the listed enumerators the
    /// visitor is not called; debug builds trap on that case.
    template <class Visitor>
    void visit_type(Visitor v) const
    {
        switch(this->type_)
        {
#define RTG_SHAPE_VISITOR_CASE(x, t) \
    case x:                          \
        v(as<t>());                  \
        return;
            RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
        }
        // Every listed type returns from inside the switch, so reaching this
        // point means `type_` is corrupt or a new type was added incompletely.
        // (The previous `assert(true)` could never fire.)
        assert(false && "rtg::shape::visit_type: unhandled type_t value");
    }

    private:
    // Default member initializers guarantee determinate values even if a
    // constructor (defined outside this header) omits one of these members.
    type_t type_ = float_type;
    std::vector<std::size_t> lens_;
    std::vector<std::size_t> strides_;
    bool packed_ = false;

    void calculate_strides();
    std::size_t element_space() const;
};

} // namespace rtg

#endif