Commit d29aa711 authored by Paul's avatar Paul
Browse files

Add reflectable equality operator

parent 76f7ae49
......@@ -8,7 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/reflect.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -59,6 +59,19 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T>
auto compute_op(rank<1>,
const T& x,
......@@ -93,6 +106,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
*
*/
......@@ -178,6 +192,12 @@ struct operation
return op.private_detail_te_get_handle().operator_shift_left(os);
}
friend bool operator==(const operation& x, const operation& y)
{
assert(x.private_detail_te_handle_mem_var);
return x.private_detail_te_get_handle().operator==(y);
}
private:
struct private_detail_te_handle_base_type
{
......@@ -190,6 +210,7 @@ struct operation
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -242,6 +263,12 @@ struct operation
return os << private_detail_te_value;
}
bool operator==(const operation& y) const override
{
using migraph::operation_equal::operator==;
return private_detail_te_value == y;
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......@@ -307,6 +334,8 @@ inline const ValueType& any_cast(const operation& x)
return *y;
}
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
#endif
} // namespace migraph
......
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#include <migraph/functional.hpp>
#include <migraph/rank.hpp>
#include <functional>
namespace migraph {
namespace detail {
template<class T, class Selector>
auto reflect_impl(rank<1>, T& x, Selector f) -> decltype(T::reflect(x, f))
{
return T::reflect(x, std::move(f));
}
template<class T, class Selector>
auto reflect_impl(rank<0>, T&, Selector)
{
return pack();
}
} // namespace detail
template<class T, class Selector>
auto reflect(T& x, Selector f)
{
return detail::reflect_impl(rank<1>{}, x, std::move(f));
}
template<class T>
auto reflect_tie(T& x)
{
return reflect(x, [](auto&& y, auto&&...) { return std::ref(y); })([](auto&&... xs) {
return std::tie(xs.get()...);
});
}
template<class T, class F>
void reflect_each(T& x, F f)
{
return reflect(x, [](auto&& y, auto... ys) { return pack(std::ref(y), ys...); })([&](auto&&... xs) {
each_args([&](auto p) {
p([&](auto&& y, auto... ys) {
f(y, ys...);
});
}, xs...);
});
}
} // namespace migraph
#endif
......@@ -6,6 +6,11 @@
struct simple_operation
{
template<class T, class F>
static auto reflect(T& x, F f)
{
return migraph::pack(f(x.data, "data"));
}
int data = 1;
std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const
......@@ -19,7 +24,7 @@ struct simple_operation
}
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << "[" << op.name() << "]";
os << op.name() << "[" << op.data << "]";
return os;
}
};
......@@ -44,9 +49,23 @@ void operation_copy_test()
migraph::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression
EXPECT(s.name() == op1.name());
EXPECT(s == op1);
// cppcheck-suppress duplicateExpression
EXPECT(op2.name() == op1.name());
EXPECT(op2 == op1);
}
void operation_equal_test()
{
simple_operation s{};
migraph::operation op1 = s;
s.data = 2;
migraph::operation op2 = op1;
migraph::operation op3 = s;
EXPECT(s != op1);
EXPECT(op2 == op1);
EXPECT(op3 != op2);
EXPECT(op3 != op1);
}
struct not_operation
......@@ -70,7 +89,7 @@ void operation_print()
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "[simple]");
EXPECT(s == "simple[1]");
}
void operation_default_print()
......@@ -85,6 +104,7 @@ void operation_default_print()
int main()
{
operation_copy_test();
operation_equal_test();
operation_any_cast();
operation_print();
operation_default_print();
......
......@@ -8,7 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/reflect.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -59,6 +59,19 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T>
auto compute_op(rank<1>,
const T& x,
......@@ -89,10 +102,16 @@ interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True),
virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<'),
friend('operator==', returns='bool', x='const operation &', y='const operation &', using='migraph::operation_equal::operator==')
)
%>
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
}
#endif
} // namespace migraph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment