Commit f8822187 authored by charlie's avatar charlie
Browse files

Fixing serialization errors

parent 8bf8e161
...@@ -132,9 +132,10 @@ struct shape ...@@ -132,9 +132,10 @@ struct shape
std::size_t type_size() const; std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const; const std::vector<dynamic_dimension>& dyn_dims() const;
const std::vector<std::size_t>& min_dyn_dims() const;
const std::vector<std::size_t>& max_dyn_dims() const; const std::vector<std::size_t> min_dyn_dims() const;
const std::vector<std::size_t>& opt_dyn_dims() const; const std::vector<std::size_t> max_dyn_dims() const;
const std::vector<std::size_t> opt_dyn_dims() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
......
...@@ -210,8 +210,6 @@ std::size_t shape::type_size() const ...@@ -210,8 +210,6 @@ std::size_t shape::type_size() const
return n; return n;
} }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
if(this->dynamic()) if(this->dynamic())
...@@ -283,8 +281,6 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -283,8 +281,6 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
}); });
} }
bool shape::dynamic() const { return (impl->m_dynamic); }
bool shape::packed() const bool shape::packed() const
{ {
if(this->dynamic()) if(this->dynamic())
...@@ -383,6 +379,43 @@ std::size_t shape::element_space() const { return impl->element_space(); } ...@@ -383,6 +379,43 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return impl->m_dynamic; }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
const std::vector<std::size_t> shape::min_dyn_dims() const
{
auto num_dims = dyn_dims().size();
std::vector<std::size_t> ret{num_dims};
for(int i = 0; i < num_dims; ++i)
{
ret.at(i) = dyn_dims().at(i).min;
}
return ret;
}
const std::vector<std::size_t> shape::max_dyn_dims() const
{
auto num_dims = dyn_dims().size();
std::vector<std::size_t> ret{num_dims};
for(int i = 0; i < num_dims; ++i)
{
ret.at(i) = dyn_dims().at(i).max;
}
return ret;
}
const std::vector<std::size_t> shape::opt_dyn_dims() const
{
auto num_dims = dyn_dims().size();
std::vector<std::size_t> ret{num_dims};
for(int i = 0; i < num_dims; ++i)
{
ret.at(i) = dyn_dims().at(i).opt;
}
return ret;
}
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
if(x.dynamic() and y.dynamic()) if(x.dynamic() and y.dynamic())
...@@ -436,10 +469,10 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; } ...@@ -436,10 +469,10 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s) void migraphx_to_value(value& v, const shape& s)
{ {
value result; value result;
result["type"] = migraphx::to_value(s.type_string());
if(s.dynamic()) if(s.dynamic())
{ {
result["dynamic"] = migraphx::to_value(s.dynamic()); result["dynamic"] = migraphx::to_value(s.dynamic());
result["type"] = migraphx::to_value(s.type_string());
result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims()); result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims());
result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims()); result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims());
result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims()); result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims());
...@@ -447,7 +480,6 @@ void migraphx_to_value(value& v, const shape& s) ...@@ -447,7 +480,6 @@ void migraphx_to_value(value& v, const shape& s)
} }
else else
{ {
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens()); result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides()); result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
...@@ -464,8 +496,7 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -464,8 +496,7 @@ void migraphx_from_value(const value& v, shape& s)
} }
else else
{ {
auto dyn = v.at("dynamic").get_bool(); if(v.contains("dynamic"))
if(dyn)
{ {
auto mins = v.at("min_dyn_dims").to_vector<std::size_t>(); auto mins = v.at("min_dyn_dims").to_vector<std::size_t>();
auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>(); auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment