Unverified Commit d1caaaa1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add to_value/from_value to operation class (#605)

* Add initial serialization

* Formatting

* Add unit tests

* Formatting

* Add tests for serialization

* Formatting

* Use or not and

* Add value test

* Formatting

* Add more tests

* Add shape serialization

* Formatting

* Add serializtion for literal and argument

* Formatting

* Add from and to value to operatation

* Formatting

* Serialize empty types

* Formatting

* Tidy fixes

* Formatting

* Fix tidy issues

* Formatting

* Reformat value type macro

* Formatting

* Handle enum types

* Formatting

* Use const ref

* Update

* Add tests for to_value/from_value

* Formatting

* Add more tests

* Rewrite test to avoid redundant assignment
parent 453517ad
......@@ -10,6 +10,7 @@
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
......@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
value to_value_op(const T& x)
{
return migraphx::to_value(x);
}
template <class T>
void from_value_op(T& x, const value& v)
{
return migraphx::from_value(v, x);
}
} // namespace detail
/*
......@@ -233,6 +246,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* value to_value() const;
* void from_value(const value& v) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
......@@ -350,6 +365,18 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input);
}
value to_value() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().to_value();
}
void from_value(const value& v)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().from_value(v);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
......@@ -385,6 +412,8 @@ struct operation
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
......@@ -493,6 +522,34 @@ struct operation
return detail::compute_op(private_detail_te_self, output, input);
}
template <class T>
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.to_value())
{
return private_detail_te_self.to_value();
}
template <class T>
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
{
return detail::to_value_op(private_detail_te_self);
}
template <class T>
static auto
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
-> decltype(private_detail_te_self.from_value(v))
{
private_detail_te_self.from_value(v);
}
template <class T>
static void
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
{
detail::from_value_op(private_detail_te_self, v);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -570,6 +627,18 @@ struct operation
char(0), private_detail_te_value, output, input);
}
value to_value() const override
{
return private_detail_te_default_to_value(char(0), private_detail_te_value);
}
void from_value(const value& v) override
{
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using migraphx::detail::operation_operators::operator<<;
......
......@@ -11,6 +11,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Avoid implicit conversion with ADL lookup
template <class T>
void migraphx_to_value(value&, const T&) = delete;
template <class T>
value to_value(const T& x);
......@@ -93,7 +97,13 @@ auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
}
template <class T>
auto to_value_impl(rank<10>, const T& x)
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
{
return x.to_value();
}
template <class T>
auto to_value_impl(rank<11>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{
value v;
......@@ -152,7 +162,13 @@ void from_value_impl(rank<5>, const value& v, T& x)
inline void from_value_impl(rank<6>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T>
auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
x.from_value(v);
}
template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{
migraphx_from_value(v, x);
}
......@@ -162,13 +178,13 @@ auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(migraphx_from_va
template <class T>
value to_value(const T& x)
{
return detail::to_value_impl(rank<10>{}, x);
return detail::to_value_impl(rank<11>{}, x);
}
template <class T>
void from_value(const value& v, T& x)
{
detail::from_value_impl(rank<7>{}, v, x);
detail::from_value_impl(rank<8>{}, v, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -173,4 +173,35 @@ TEST_CASE(check_run_finalize_throw)
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}
TEST_CASE(check_to_value1)
{
migraphx::operation op = simple_operation{};
auto v = op.to_value();
EXPECT(v == migraphx::value{{"data", 1}});
}
TEST_CASE(check_to_value2)
{
migraphx::operation op = simple_operation{};
auto v = migraphx::to_value(op);
EXPECT(v == migraphx::value{{"data", 1}});
}
TEST_CASE(check_from_value1)
{
migraphx::operation op1 = simple_operation{};
migraphx::operation op2 = simple_operation{3};
op1.from_value({{"data", 3}});
EXPECT(op1 == op2);
}
TEST_CASE(check_from_value2)
{
migraphx::operation op1 = migraphx::from_value<simple_operation>({{"data", 3}});
migraphx::operation op2 = simple_operation{3};
EXPECT(op1 == op2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -10,6 +10,7 @@
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
......@@ -218,6 +219,18 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
template <class T>
value to_value_op(const T& x)
{
return migraphx::to_value(x);
}
template <class T>
void from_value_op(T& x, const value& v)
{
return migraphx::from_value(v, x);
}
} // namespace detail
<%
......@@ -251,6 +264,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
input = 'const std::vector<argument>&',
const = True,
default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
virtual('from_value', v = 'const value&', default = 'detail::from_value_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment