Unverified Commit 6e1f9f20 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Context serialization (#607)



* Add initial serialization

* Formatting

* Add unit tests

* Formatting

* Add tests for serialization

* Formatting

* Use or not and

* Add value test

* Formatting

* Add more tests

* Add shape serialization

* Formatting

* Add serializtion for literal and argument

* Formatting

* Add from and to value to operatation

* Formatting

* Serialize empty types

* Formatting

* Tidy fixes

* Formatting

* Fix tidy issues

* Formatting

* Reformat value type macro

* Formatting

* Handle enum types

* Formatting

* Use const ref

* Update

* Add tests for to_value/from_value

* Formatting

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* remove the from/to_value method for the generate context struct

* clang format

* code backup

* Dont print literal data in hip_copy_literal

* clang format

* add unit test to have better coverage

* remove unnecessary code

* remove unnecessary code

* fix review comments

* clang format

* fix review comments
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 59b80d4e
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -25,11 +26,24 @@ struct context ...@@ -25,11 +26,24 @@ struct context
#else #else
template <class T>
value to_value_context(const T&)
{
return value{};
}
template <class T>
void from_value_context(T&, const value&)
{
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
* struct context * struct context
* { * {
* value to_value() const;
* void from_value(const value& v) ;
* void finish() const; * void finish() const;
* }; * };
* *
...@@ -98,6 +112,18 @@ struct context ...@@ -98,6 +112,18 @@ struct context
return private_detail_te_get_handle().type(); return private_detail_te_get_handle().type();
} }
value to_value() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().to_value();
}
void from_value(const value& v)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().from_value(v);
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -117,9 +143,39 @@ struct context ...@@ -117,9 +143,39 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual void finish() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual void finish() const = 0;
}; };
template <class T>
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.to_value())
{
return private_detail_te_self.to_value();
}
template <class T>
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
{
return to_value_context(private_detail_te_self);
}
template <class T>
static auto
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
-> decltype(private_detail_te_self.from_value(v))
{
private_detail_te_self.from_value(v);
}
template <class T>
static void
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
{
from_value_context(private_detail_te_self, v);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -148,6 +204,18 @@ struct context ...@@ -148,6 +204,18 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
value to_value() const override
{
return private_detail_te_default_to_value(char(0), private_detail_te_value);
}
void from_value(const value& v) override
{
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
...@@ -215,6 +283,9 @@ inline const ValueType& any_cast(const context& x) ...@@ -215,6 +283,9 @@ inline const ValueType& any_cast(const context& x)
return *y; return *y;
} }
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <unordered_map> #include <unordered_map>
#include <memory>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -168,11 +169,36 @@ struct context ...@@ -168,11 +169,36 @@ struct context
return hip_event_ptr{event}; return hip_event_ptr{event};
} }
value to_value() const
{
value result;
result["events"] = events.size();
result["streams"] = current_device->nstreams();
return result;
}
void from_value(const value& v)
{
auto v_events = v.at("events");
std::size_t n_events = v_events.without_key().to<std::size_t>();
this->create_events(n_events - 1);
auto v_streams = v.at("streams");
std::size_t n_streams = v_streams.without_key().to<std::size_t>();
this->current_device = std::make_shared<hip_device>(0, n_streams);
}
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events; std::vector<shared<hip_event_ptr>> events;
}; };
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#include <migraphx/serialize.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
TEST_CASE(context)
{
migraphx::context ctx = migraphx::cpu::context{};
migraphx::value v = ctx.to_value();
EXPECT(v.empty());
migraphx::context cpu_ctx = migraphx::cpu::context{};
cpu_ctx.from_value(v);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <migraphx/verify.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/context.hpp>
#include "test.hpp"
TEST_CASE(gpu_context)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
auto v = ctx.to_value();
EXPECT(v.size() == 2);
EXPECT(v.contains("events"));
EXPECT(v.at("events").without_key().to<std::size_t>() == 0);
EXPECT(v.contains("streams"));
EXPECT(v.at("streams").without_key().to<std::size_t>() == 3);
migraphx::gpu::context g_ctx;
g_ctx.from_value(v);
auto v1 = g_ctx.to_value();
EXPECT(v == v1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -25,11 +26,26 @@ struct context ...@@ -25,11 +26,26 @@ struct context
#else #else
template <class T>
value to_value_context(const T&)
{
return value{};
}
template <class T>
void from_value_context(T&, const value&){}
<% <%
interface('context', interface('context',
virtual('finish', returns='void', const=True) virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
) virtual('from_value', v = 'const value&', default = 'from_value_context'),
%> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx)
{
v = ctx.to_value();
}
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment