Unverified Commit d1caaaa1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add to_value/from_value to operation class (#605)

* Add initial serialization

* Formatting

* Add unit tests

* Formatting

* Add tests for serialization

* Formatting

* Use or not and

* Add value test

* Formatting

* Add more tests

* Add shape serialization

* Formatting

* Add serializtion for literal and argument

* Formatting

* Add from and to value to operatation

* Formatting

* Serialize empty types

* Formatting

* Tidy fixes

* Formatting

* Fix tidy issues

* Formatting

* Reformat value type macro

* Formatting

* Handle enum types

* Formatting

* Use const ref

* Update

* Add tests for to_value/from_value

* Formatting

* Add more tests

* Rewrite test to avoid redundant assignment
parent 453517ad
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
value to_value_op(const T& x)
{
return migraphx::to_value(x);
}
template <class T>
void from_value_op(T& x, const value& v)
{
return migraphx::from_value(v, x);
}
} // namespace detail } // namespace detail
/* /*
...@@ -233,6 +246,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -233,6 +246,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const; * argument compute(const shape& output,const std::vector<argument>& input) const;
* value to_value() const;
* void from_value(const value& v) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ; * friend bool operator==(const operation & x,const operation & y) ;
* }; * };
...@@ -350,6 +365,18 @@ struct operation ...@@ -350,6 +365,18 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input); return (*this).private_detail_te_get_handle().compute(output, input);
} }
value to_value() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().to_value();
}
void from_value(const value& v)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().from_value(v);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
{ {
assert(op.private_detail_te_handle_mem_var); assert(op.private_detail_te_handle_mem_var);
...@@ -385,6 +412,8 @@ struct operation ...@@ -385,6 +412,8 @@ struct operation
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0; virtual bool operator==(const operation& y) const = 0;
}; };
...@@ -493,6 +522,34 @@ struct operation ...@@ -493,6 +522,34 @@ struct operation
return detail::compute_op(private_detail_te_self, output, input); return detail::compute_op(private_detail_te_self, output, input);
} }
template <class T>
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.to_value())
{
return private_detail_te_self.to_value();
}
template <class T>
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
{
return detail::to_value_op(private_detail_te_self);
}
template <class T>
static auto
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
-> decltype(private_detail_te_self.from_value(v))
{
private_detail_te_self.from_value(v);
}
template <class T>
static void
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
{
detail::from_value_op(private_detail_te_self, v);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -570,6 +627,18 @@ struct operation ...@@ -570,6 +627,18 @@ struct operation
char(0), private_detail_te_value, output, input); char(0), private_detail_te_value, output, input);
} }
value to_value() const override
{
return private_detail_te_default_to_value(char(0), private_detail_te_value);
}
void from_value(const value& v) override
{
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using migraphx::detail::operation_operators::operator<<; using migraphx::detail::operation_operators::operator<<;
......
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Avoid implicit conversion with ADL lookup
template <class T>
void migraphx_to_value(value&, const T&) = delete;
template <class T> template <class T>
value to_value(const T& x); value to_value(const T& x);
...@@ -93,7 +97,13 @@ auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x)) ...@@ -93,7 +97,13 @@ auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
} }
template <class T> template <class T>
auto to_value_impl(rank<10>, const T& x) auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
{
return x.to_value();
}
template <class T>
auto to_value_impl(rank<11>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{}) -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{ {
value v; value v;
...@@ -152,7 +162,13 @@ void from_value_impl(rank<5>, const value& v, T& x) ...@@ -152,7 +162,13 @@ void from_value_impl(rank<5>, const value& v, T& x)
inline void from_value_impl(rank<6>, const value& v, std::string& x) { x = v.to<std::string>(); } inline void from_value_impl(rank<6>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
x.from_value(v);
}
template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -162,13 +178,13 @@ auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(migraphx_from_va ...@@ -162,13 +178,13 @@ auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(migraphx_from_va
template <class T> template <class T>
value to_value(const T& x) value to_value(const T& x)
{ {
return detail::to_value_impl(rank<10>{}, x); return detail::to_value_impl(rank<11>{}, x);
} }
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<7>{}, v, x); detail::from_value_impl(rank<8>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -173,4 +173,35 @@ TEST_CASE(check_run_finalize_throw) ...@@ -173,4 +173,35 @@ TEST_CASE(check_run_finalize_throw)
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); })); EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
} }
TEST_CASE(check_to_value1)
{
migraphx::operation op = simple_operation{};
auto v = op.to_value();
EXPECT(v == migraphx::value{{"data", 1}});
}
TEST_CASE(check_to_value2)
{
migraphx::operation op = simple_operation{};
auto v = migraphx::to_value(op);
EXPECT(v == migraphx::value{{"data", 1}});
}
TEST_CASE(check_from_value1)
{
migraphx::operation op1 = simple_operation{};
migraphx::operation op2 = simple_operation{3};
op1.from_value({{"data", 3}});
EXPECT(op1 == op2);
}
TEST_CASE(check_from_value2)
{
migraphx::operation op1 = migraphx::from_value<simple_operation>({{"data", 3}});
migraphx::operation op2 = simple_operation{3};
EXPECT(op1 == op2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
value to_value_op(const T& x)
{
return migraphx::to_value(x);
}
template <class T>
void from_value_op(T& x, const value& v)
{
return migraphx::from_value(v, x);
}
} // namespace detail } // namespace detail
<% <%
...@@ -251,6 +264,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -251,6 +264,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'detail::compute_op'), default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
virtual('from_value', v = 'const value&', default = 'detail::from_value_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment