#ifndef GUARD_RTGLIB_OPERAND_HPP #define GUARD_RTGLIB_OPERAND_HPP #include #include #include #include #include #include #include namespace rtg { /* * Type-erased interface for: * * struct operand * { * std::string name() const; * shape compute_shape(std::vector input) const; * argument compute(std::vector input) const; * }; * */ struct operand { // Constructors operand() = default; template operand(TypeErased_T_ value) : handle_mem_var_( std::make_shared::type>>( std::forward(value))) { } // Assignment template operand& operator=(TypeErased_T_ value) { if(handle_mem_var_.unique()) *handle_mem_var_ = std::forward(value); else if(!handle_mem_var_) handle_mem_var_ = std::make_shared(std::forward(value)); return *this; } std::string name() const { assert(handle_mem_var_); return get_handle_().name(); } shape compute_shape(std::vector input) const { assert(handle_mem_var_); return get_handle_().compute_shape(std::move(input)); } argument compute(std::vector input) const { assert(handle_mem_var_); return get_handle_().compute(std::move(input)); } private: struct handle_base_type_ { virtual ~handle_base_type_() {} virtual std::shared_ptr clone() const = 0; virtual std::string name() const = 0; virtual shape compute_shape(std::vector input) const = 0; virtual argument compute(std::vector input) const = 0; }; template struct handle_type_ : handle_base_type_ { template handle_type_( TypeErased_T_ value, typename std::enable_if::value>::type* = nullptr) : value_(value) { } template handle_type_(TypeErased_T_ value, typename std::enable_if::value, int>::type* = nullptr) noexcept : value_(std::move(value)) { } std::shared_ptr clone() const override { return std::make_shared(value_); } std::string name() const override { return value_.name(); } shape compute_shape(std::vector input) const override { return value_.compute_shape(std::move(input)); } argument compute(std::vector input) const override { return value_.compute(std::move(input)); } TypeErased_T_ value_; }; template struct handle_type_> : handle_type_ { handle_type_(std::reference_wrapper ref) : handle_type_(ref.get()) { } }; const handle_base_type_& get_handle_() const { return *handle_mem_var_; } handle_base_type_& get_handle_() { if(!handle_mem_var_.unique()) handle_mem_var_ = handle_mem_var_->clone(); return *handle_mem_var_; } std::shared_ptr handle_mem_var_; }; } // namespace rtg #endif