Commit d9cd32a4 authored by charlie's avatar charlie
Browse files

Seralize and reflect changes

parent 38b5c752
......@@ -65,21 +65,15 @@ struct shape
std::size_t max = 0;
std::size_t opt = 0;
bool is_fixed() const { return min == max; }
bool has_optimal() const { return opt != 0; }
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return !(x == y);
}
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
template <class Self, class F>
static auto reflect(Self& self, F f);
bool is_fixed() const;
bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
};
static const std::vector<type_t>& types();
......
......@@ -209,6 +209,7 @@ shape shape::from_permutation(type_t t,
return result;
}
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
......@@ -447,6 +448,39 @@ std::vector<std::size_t> shape::opt_lens() const
;
}
bool shape::dynamic_dimension::is_fixed() const
{
return this->min == this->max;
}
bool shape::dynamic_dimension::has_optimal() const
{
return opt != 0;
}
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"),
f(self.max, "max"),
f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
bool operator==(const shape& x, const shape& y)
{
if(x.dynamic() and y.dynamic())
......@@ -501,20 +535,10 @@ void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
if(s.dynamic())
{
result["dynamic"] = migraphx::to_value(s.dynamic());
result["min_lens"] = migraphx::to_value(s.min_lens());
result["max_lens"] = migraphx::to_value(s.max_lens());
result["opt_lens"] = migraphx::to_value(s.opt_lens());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
else
{
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
v = result;
}
......@@ -527,26 +551,34 @@ void migraphx_from_value(const value& v, shape& s)
}
else
{
if(v.contains("dynamic"))
{
auto mins = v.at("min_lens").to_vector<std::size_t>();
auto maxes = v.at("max_lens").to_vector<std::size_t>();
auto opts = v.at("opt_lens").to_vector<std::size_t>();
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
for(int i = 0; i < num_dims; ++i)
{
dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
}
s = shape{shape::parse_type(t), dyn_dims};
}
else
if(v.at("dynamic_dimensions").empty())
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(
v_dd.begin(),
v_dd.end(),
dyn_dims.begin(),
[](migraphx::value x)
{
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
}
);
s = shape{
shape::parse_type(t),
dyn_dims
};
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment