#ifndef RTG_GUARD_RTGLIB_OPERAND_HPP #define RTG_GUARD_RTGLIB_OPERAND_HPP #include #include #include #include #include #include #include namespace rtg { /* * Type-erased interface for: * * struct operation * { * std::string name() const; * shape compute_shape(std::vector input) const; * argument compute(std::vector input) const; * }; * */ struct operation { // Constructors operation() = default; template operation(PrivateDetailTypeErasedT value) : private_detail_te_handle_mem_var( std::make_shared::type>>( std::forward(value))) { } // Assignment template operation& operator=(PrivateDetailTypeErasedT value) { if(private_detail_te_handle_mem_var.unique()) *private_detail_te_handle_mem_var = std::forward(value); else if(!private_detail_te_handle_mem_var) private_detail_te_handle_mem_var = std::make_shared( std::forward(value)); return *this; } std::string name() const { assert(private_detail_te_handle_mem_var); return private_detail_te_get_handle().name(); } shape compute_shape(std::vector input) const { assert(private_detail_te_handle_mem_var); return private_detail_te_get_handle().compute_shape(std::move(input)); } argument compute(std::vector input) const { assert(private_detail_te_handle_mem_var); return private_detail_te_get_handle().compute(std::move(input)); } private: struct private_detail_te_handle_base_type { virtual ~private_detail_te_handle_base_type() {} virtual std::shared_ptr clone() const = 0; virtual std::string name() const = 0; virtual shape compute_shape(std::vector input) const = 0; virtual argument compute(std::vector input) const = 0; }; template struct private_detail_te_handle_type : private_detail_te_handle_base_type { template private_detail_te_handle_type( PrivateDetailTypeErasedT value, typename std::enable_if::value>::type* = nullptr) : private_detail_te_value(value) { } template private_detail_te_handle_type( PrivateDetailTypeErasedT value, typename std::enable_if::value, int>::type* = nullptr) noexcept : private_detail_te_value(std::move(value)) { } std::shared_ptr clone() const override { return std::make_shared(private_detail_te_value); } std::string name() const override { return private_detail_te_value.name(); } shape compute_shape(std::vector input) const override { return private_detail_te_value.compute_shape(std::move(input)); } argument compute(std::vector input) const override { return private_detail_te_value.compute(std::move(input)); } PrivateDetailTypeErasedT private_detail_te_value; }; template struct private_detail_te_handle_type> : private_detail_te_handle_type { private_detail_te_handle_type(std::reference_wrapper ref) : private_detail_te_handle_type(ref.get()) { } }; const private_detail_te_handle_base_type& private_detail_te_get_handle() const { return *private_detail_te_handle_mem_var; } private_detail_te_handle_base_type& private_detail_te_get_handle() { if(!private_detail_te_handle_mem_var.unique()) private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); return *private_detail_te_handle_mem_var; } std::shared_ptr private_detail_te_handle_mem_var; }; } // namespace rtg #endif