Unverified Commit 6e1f9f20 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Context serialization (#607)



* Add initial serialization

* Formatting

* Add unit tests

* Formatting

* Add tests for serialization

* Formatting

* Use or not and

* Add value test

* Formatting

* Add more tests

* Add shape serialization

* Formatting

* Add serializtion for literal and argument

* Formatting

* Add from and to value to operatation

* Formatting

* Serialize empty types

* Formatting

* Tidy fixes

* Formatting

* Fix tidy issues

* Formatting

* Reformat value type macro

* Formatting

* Handle enum types

* Formatting

* Use const ref

* Update

* Add tests for to_value/from_value

* Formatting

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* remove the from/to_value method for the generate context struct

* clang format

* code backup

* Dont print literal data in hip_copy_literal

* clang format

* add unit test to have better coverage

* remove unnecessary code

* remove unnecessary code

* fix review comments

* clang format

* fix review comments
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 59b80d4e
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -25,11 +26,24 @@ struct context
#else
template <class T>
value to_value_context(const T&)
{
return value{};
}
template <class T>
void from_value_context(T&, const value&)
{
}
/*
* Type-erased interface for:
*
* struct context
* {
* value to_value() const;
* void from_value(const value& v) ;
* void finish() const;
* };
*
......@@ -98,6 +112,18 @@ struct context
return private_detail_te_get_handle().type();
}
value to_value() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().to_value();
}
void from_value(const value& v)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().from_value(v);
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -117,9 +143,39 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void finish() const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual void finish() const = 0;
};
template <class T>
static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.to_value())
{
return private_detail_te_self.to_value();
}
template <class T>
static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
{
return to_value_context(private_detail_te_self);
}
template <class T>
static auto
private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
-> decltype(private_detail_te_self.from_value(v))
{
private_detail_te_self.from_value(v);
}
template <class T>
static void
private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
{
from_value_context(private_detail_te_self, v);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -148,6 +204,18 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); }
value to_value() const override
{
return private_detail_te_default_to_value(char(0), private_detail_te_value);
}
void from_value(const value& v) override
{
private_detail_te_default_from_value(char(0), private_detail_te_value, v);
}
void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
......@@ -215,6 +283,9 @@ inline const ValueType& any_cast(const context& x)
return *y;
}
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -8,6 +8,7 @@
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <memory>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -168,11 +169,36 @@ struct context
return hip_event_ptr{event};
}
value to_value() const
{
value result;
result["events"] = events.size();
result["streams"] = current_device->nstreams();
return result;
}
void from_value(const value& v)
{
auto v_events = v.at("events");
std::size_t n_events = v_events.without_key().to<std::size_t>();
this->create_events(n_events - 1);
auto v_streams = v.at("streams");
std::size_t n_streams = v_streams.without_key().to<std::size_t>();
this->current_device = std::make_shared<hip_device>(0, n_streams);
}
private:
// TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events;
};
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#include <migraphx/serialize.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
TEST_CASE(context)
{
migraphx::context ctx = migraphx::cpu::context{};
migraphx::value v = ctx.to_value();
EXPECT(v.empty());
migraphx::context cpu_ctx = migraphx::cpu::context{};
cpu_ctx.from_value(v);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <migraphx/verify.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/context.hpp>
#include "test.hpp"
TEST_CASE(gpu_context)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
auto v = ctx.to_value();
EXPECT(v.size() == 2);
EXPECT(v.contains("events"));
EXPECT(v.at("events").without_key().to<std::size_t>() == 0);
EXPECT(v.contains("streams"));
EXPECT(v.at("streams").without_key().to<std::size_t>() == 3);
migraphx::gpu::context g_ctx;
g_ctx.from_value(v);
auto v1 = g_ctx.to_value();
EXPECT(v == v1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -25,11 +26,26 @@ struct context
#else
template <class T>
value to_value_context(const T&)
{
return value{};
}
template <class T>
void from_value_context(T&, const value&){}
<%
interface('context',
virtual('finish', returns='void', const=True)
)
%>
interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx)
{
v = ctx.to_value();
}
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment