Commit b889d472 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into develop

parents 791952a5 9e02008c
......@@ -52,7 +52,8 @@ struct literal : raw_data<literal>
fill(start, end);
}
literal(const shape& s, const char* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{
std::copy(x, x + s.bytes(), buffer.get());
}
......
......@@ -118,7 +118,8 @@ struct value
m(uint64, std::uint64_t) \
m(float, double) \
m(string, std::string) \
m(bool, bool)
m(bool, bool) \
m(binary, value::binary)
// clang-format on
enum type_t
{
......@@ -139,6 +140,20 @@ struct value
using const_pointer = const value_type*;
using array = std::vector<value>;
using object = std::unordered_map<std::string, value>;
struct binary : std::vector<std::uint8_t>
{
using base = std::vector<std::uint8_t>;
binary() {}
template <class Container,
MIGRAPHX_REQUIRES(sizeof(*std::declval<Container>().begin()) == 1)>
explicit binary(const Container& c) : base(c.begin(), c.end())
{
}
template <class T>
binary(T* data, std::size_t s) : base(data, data + s)
{
}
};
value() = default;
......
......@@ -39,6 +39,12 @@ void value_to_json(const T& x, json& j)
j = x;
}
void value_to_json(const value::binary& x, json& j)
{
j = json::object();
j["bytes"] = std::vector<int>(x.begin(), x.end());
}
void value_to_json(const std::vector<value>& x, json& j)
{
for(const auto& v : x)
......@@ -97,6 +103,12 @@ migraphx::value value_from_json(const json& j)
break;
case json::value_t::object:
if(j.contains("bytes") and j.size() == 1)
{
val = migraphx::value::binary{j["bytes"].get<std::vector<std::uint8_t>>()};
}
else
{
val = migraphx::value::object{};
for(const auto& item : j.items())
{
......@@ -104,6 +116,7 @@ migraphx::value value_from_json(const json& j)
const json& jv = item.value();
val[key] = jv.get<value>();
}
}
break;
case json::value_t::binary: MIGRAPHX_THROW("Convert JSON to Value: binary type not supported!");
......
......@@ -45,7 +45,10 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = o.as<std::string>();
break;
}
case msgpack::type::BIN: { MIGRAPHX_THROW("msgpack BIN type not supported.");
case msgpack::type::BIN:
{
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
case msgpack::type::ARRAY:
{
......@@ -75,6 +78,21 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
}
};
template <>
struct pack<migraphx::value::binary>
{
template <class Stream>
packer<Stream>& operator()(msgpack::packer<Stream>& o,
const migraphx::value::binary& x) const
{
const auto* data = reinterpret_cast<const char*>(x.data());
auto size = x.size();
o.pack_bin(size);
o.pack_bin_body(data, size);
return o;
}
};
template <>
struct pack<migraphx::value>
{
......
......@@ -284,7 +284,7 @@ std::vector<argument> program::eval(parameter_map params) const
}
}
const int program_file_version = 2;
const int program_file_version = 3;
value program::to_value() const
{
......
......@@ -11,7 +11,7 @@ void raw_data_to_value(value& v, const RawData& rd)
{
value result;
result["shape"] = migraphx::to_value(rd.get_shape());
result["data"] = std::string(rd.data(), rd.data() + rd.get_shape().bytes());
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result;
}
......@@ -19,7 +19,7 @@ void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); }
void migraphx_from_value(const value& v, literal& l)
{
auto s = migraphx::from_value<shape>(v.at("shape"));
l = literal(s, v.at("data").get_string().data());
l = literal(s, v.at("data").get_binary().data());
}
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
......
......@@ -458,6 +458,15 @@ void print_value(std::ostream& os, const std::vector<value>& x)
os << "}";
}
void print_value(std::ostream& os, const value::binary& x)
{
// Convert binary to integers
std::vector<int> v(x.begin(), x.end());
os << "{";
os << to_string_range(v);
os << "}";
}
std::ostream& operator<<(std::ostream& os, const value& d)
{
d.visit([&](auto&& y) { print_value(os, y); });
......
......@@ -755,4 +755,23 @@ TEST_CASE(value_init_from_vector)
EXPECT(values.at("a").to_vector<int>() == v);
}
TEST_CASE(value_binary_default)
{
migraphx::value v;
v = migraphx::value::binary{};
EXPECT(v.is_binary());
EXPECT(v.get_key().empty());
}
TEST_CASE(value_binary)
{
migraphx::value v;
std::vector<std::uint8_t> data(20);
std::iota(data.begin(), data.end(), 0);
v = migraphx::value::binary{data};
EXPECT(v.is_binary());
EXPECT(v.get_binary() == data);
EXPECT(v.get_key().empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment