#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #include #include #include #include #include #include #include #include #include #include #include #include namespace migraph { #ifdef DOXYGEN /// The operation interface represents an action an instruction will perform. All /// operation classes must be CopyConstructible. struct operation { /// A unique name identifying the operation std::string name() const; /// This is used to compute the resulting shape from an operation. If an /// operation cannot be run with input shapes, then it should throw an /// exception. shape compute_shape(const std::vector& input) const; /** * @brief This performs the operation's computation. * * This method can be optional when the operation is only used as a placeholder to be lowered * later on. * * @param ctx This is the context created by the `target` during compilation. Implementations * can use the target's `context` class rather than the `context` interface class. * @param output This is the output shape. It is equivalent to running `compute_shape` with each * `shape` of the `argument`. * @param input This is the `argument` result from the previous instruction's computation. * @return Return an `argument` of the result computation. The `shape` of `argument` should be * the same the `output` shape. */ argument compute(context& ctx, const shape& output, const std::vector& input) const; /// An optional stream operator to print the operation. When this is not /// implemented, it will just print the operation's name. friend std::ostream& operator<<(std::ostream& os, const operation& op); }; #else namespace operation_stream { template auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) { os << x.name(); char delim = '['; reflect_each(x, [&](auto& y, auto name) { os << delim; os << name << "="; stream_write_value(os, y); delim = ','; }); if(delim == ',') os << "]"; return os; } } // namespace operation_stream namespace operation_equal { template auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) { if(x.name() != y.name()) return false; const auto& yy = any_cast(y); return reflect_tie(x) == reflect_tie(yy); } } // namespace operation_equal template auto compute_op(rank<1>, const T& x, context& ctx, const shape& output_shape, const std::vector& input) -> decltype(x.compute(auto_any_cast(ctx), output_shape, input)) { return x.compute(auto_any_cast(ctx), output_shape, input); } template argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector&) { std::string name = x.name(); MIGRAPH_THROW("Not computable: " + name); } template argument compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector& input) { return compute_op(rank<1>{}, x, ctx, output_shape, input); } /* * Type-erased interface for: * * struct operation * { * std::string name() const; * shape compute_shape(const std::vector& input) const; * argument compute(context& ctx,const shape& output,const std::vector& input) const; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend bool operator==(const operation & x,const operation & y) ; * }; * */ struct operation { // Constructors operation() = default; template operation(PrivateDetailTypeErasedT value) : private_detail_te_handle_mem_var( std::make_shared::type>>( std::forward(value))) { } // Assignment template operation& operator=(PrivateDetailTypeErasedT value) { if(private_detail_te_handle_mem_var.unique()) *private_detail_te_handle_mem_var = std::forward(value); else if(!private_detail_te_handle_mem_var) private_detail_te_handle_mem_var = std::make_shared( std::forward(value)); return *this; } // Cast template PrivateDetailTypeErasedT* any_cast() { return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) .private_detail_te_value) : nullptr; } template const typename std::remove_cv::type* any_cast() const { return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) .private_detail_te_value) : nullptr; } const std::type_info& type_id() const { if(private_detail_te_handle_empty()) return typeid(std::nullptr_t); else return private_detail_te_get_handle().type(); } std::string name() const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().name(); } shape compute_shape(const std::vector& input) const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().compute_shape(input); } argument compute(context& ctx, const shape& output, const std::vector& input) const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().compute(ctx, output, input); } friend std::ostream& operator<<(std::ostream& os, const operation& op) { assert(op.private_detail_te_handle_mem_var); return op.private_detail_te_get_handle().operator_shift_left(os); } friend bool operator==(const operation& x, const operation& y) { assert(x.private_detail_te_handle_mem_var); return x.private_detail_te_get_handle().operator==(y); } private: struct private_detail_te_handle_base_type { virtual ~private_detail_te_handle_base_type() {} virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; virtual std::string name() const = 0; virtual shape compute_shape(const std::vector& input) const = 0; virtual argument compute(context& ctx, const shape& output, const std::vector& input) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual bool operator==(const operation& y) const = 0; }; template struct private_detail_te_handle_type : private_detail_te_handle_base_type { template private_detail_te_handle_type( PrivateDetailTypeErasedT value, typename std::enable_if::value>::type* = nullptr) : private_detail_te_value(value) { } template private_detail_te_handle_type( PrivateDetailTypeErasedT value, typename std::enable_if::value, int>::type* = nullptr) noexcept : private_detail_te_value(std::move(value)) { } std::shared_ptr clone() const override { return std::make_shared(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); } std::string name() const override { return private_detail_te_value.name(); } shape compute_shape(const std::vector& input) const override { return private_detail_te_value.compute_shape(input); } argument compute(context& ctx, const shape& output, const std::vector& input) const override { return compute_op(private_detail_te_value, ctx, output, input); } std::ostream& operator_shift_left(std::ostream& os) const override { using migraph::operation_stream::operator<<; return os << private_detail_te_value; } bool operator==(const operation& y) const override { using migraph::operation_equal::operator==; return private_detail_te_value == y; } PrivateDetailTypeErasedT private_detail_te_value; }; template struct private_detail_te_handle_type> : private_detail_te_handle_type { private_detail_te_handle_type(std::reference_wrapper ref) : private_detail_te_handle_type(ref.get()) { } }; bool private_detail_te_handle_empty() const { return private_detail_te_handle_mem_var == nullptr; } const private_detail_te_handle_base_type& private_detail_te_get_handle() const { assert(private_detail_te_handle_mem_var != nullptr); return *private_detail_te_handle_mem_var; } private_detail_te_handle_base_type& private_detail_te_get_handle() { assert(private_detail_te_handle_mem_var != nullptr); if(!private_detail_te_handle_mem_var.unique()) private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); return *private_detail_te_handle_mem_var; } std::shared_ptr private_detail_te_handle_mem_var; }; template inline const ValueType* any_cast(const operation* x) { return x->any_cast(); } template inline ValueType* any_cast(operation* x) { return x->any_cast(); } template inline ValueType& any_cast(operation& x) { auto* y = x.any_cast::type>(); if(y == nullptr) throw std::bad_cast(); return *y; } template inline const ValueType& any_cast(const operation& x) { const auto* y = x.any_cast::type>(); if(y == nullptr) throw std::bad_cast(); return *y; } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } #endif } // namespace migraph #endif