Commit d9cd32a4 authored by charlie's avatar charlie
Browse files

Seralize and reflect changes

parent 38b5c752
...@@ -65,21 +65,15 @@ struct shape ...@@ -65,21 +65,15 @@ struct shape
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::size_t opt = 0;
bool is_fixed() const { return min == max; } template <class Self, class F>
bool has_optimal() const { return opt != 0; } static auto reflect(Self& self, F f);
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{ bool is_fixed() const;
return (x.min == y.min and x.max == y.max and x.opt == y.opt); bool has_optimal() const;
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y) friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
{ friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
return !(x == y); friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
}
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
......
...@@ -209,6 +209,7 @@ shape shape::from_permutation(type_t t, ...@@ -209,6 +209,7 @@ shape shape::from_permutation(type_t t,
return result; return result;
} }
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
...@@ -447,6 +448,39 @@ std::vector<std::size_t> shape::opt_lens() const ...@@ -447,6 +448,39 @@ std::vector<std::size_t> shape::opt_lens() const
; ;
} }
bool shape::dynamic_dimension::is_fixed() const
{
return this->min == this->max;
}
bool shape::dynamic_dimension::has_optimal() const
{
return opt != 0;
}
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"),
f(self.max, "max"),
f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
if(x.dynamic() and y.dynamic()) if(x.dynamic() and y.dynamic())
...@@ -501,20 +535,10 @@ void migraphx_to_value(value& v, const shape& s) ...@@ -501,20 +535,10 @@ void migraphx_to_value(value& v, const shape& s)
{ {
value result; value result;
result["type"] = migraphx::to_value(s.type_string()); result["type"] = migraphx::to_value(s.type_string());
if(s.dynamic()) result["lens"] = migraphx::to_value(s.lens());
{ result["strides"] = migraphx::to_value(s.strides());
result["dynamic"] = migraphx::to_value(s.dynamic()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
result["min_lens"] = migraphx::to_value(s.min_lens()); result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
result["max_lens"] = migraphx::to_value(s.max_lens());
result["opt_lens"] = migraphx::to_value(s.opt_lens());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
else
{
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
v = result; v = result;
} }
...@@ -527,26 +551,34 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -527,26 +551,34 @@ void migraphx_from_value(const value& v, shape& s)
} }
else else
{ {
if(v.contains("dynamic")) if(v.at("dynamic_dimensions").empty())
{
auto mins = v.at("min_lens").to_vector<std::size_t>();
auto maxes = v.at("max_lens").to_vector<std::size_t>();
auto opts = v.at("opt_lens").to_vector<std::size_t>();
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
for(int i = 0; i < num_dims; ++i)
{
dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
}
s = shape{shape::parse_type(t), dyn_dims};
}
else
{ {
s = shape{shape::parse_type(t), s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(), v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()}; v.at("strides").to_vector<std::size_t>()};
} }
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(
v_dd.begin(),
v_dd.end(),
dyn_dims.begin(),
[](migraphx::value x)
{
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
}
);
s = shape{
shape::parse_type(t),
dyn_dims
};
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment