Commit 8bf8e161 authored by charlie's avatar charlie
Browse files

Shape class changes to handle dynamic

* More throw errors for functions that don't make sense for dynamic shape
* Print output changes
* Serialization changes
parent 4680518a
...@@ -64,8 +64,21 @@ struct shape ...@@ -64,8 +64,21 @@ struct shape
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::size_t opt = 0;
bool is_fixed() const { return min == max; }; bool is_fixed() const { return min == max; }
bool has_optimal() const { return opt != 0; }; bool has_optimal() const { return opt != 0; }
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return !(x == y);
}
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
...@@ -100,11 +113,28 @@ struct shape ...@@ -100,11 +113,28 @@ struct shape
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor. Multiply the lengths.
*/
std::size_t elements() const; std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data.
* Includes subshapes.
*/
std::size_t bytes() const; std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const; std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const; const std::vector<dynamic_dimension>& dyn_dims() const;
const std::vector<std::size_t>& min_dyn_dims() const;
const std::vector<std::size_t>& max_dyn_dims() const;
const std::vector<std::size_t>& opt_dyn_dims() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
......
...@@ -184,6 +184,10 @@ std::size_t shape::elements() const { return impl->elements(); } ...@@ -184,6 +184,10 @@ std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: bytes() called on dynamic shape");
}
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -210,18 +214,30 @@ const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return im ...@@ -210,18 +214,30 @@ const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return im
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(const std::vector<std::size_t>& l) const std::size_t shape::index(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
...@@ -338,6 +354,10 @@ shape shape::normalize_standard() const ...@@ -338,6 +354,10 @@ shape shape::normalize_standard() const
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
assert(l.size() == this->lens().size()); assert(l.size() == this->lens().size());
auto perm = find_permutation(*this); auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm); return shape::from_permutation(t, l, perm);
...@@ -345,6 +365,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const ...@@ -345,6 +365,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape shape::with_lens(const std::vector<std::size_t>& l) const shape shape::with_lens(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
...@@ -361,19 +385,35 @@ std::string shape::type_string() const { return name(this->type()); } ...@@ -361,19 +385,35 @@ std::string shape::type_string() const { return name(this->type()); }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and if(x.dynamic() and y.dynamic())
{
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
x.sub_shapes() == y.sub_shapes());
}
return x.impl == y.impl or
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes()); x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
} }
bool operator!=(const shape& x, const shape& y) { return !(x == y); } bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
if(x.sub_shapes().empty()) if(x.sub_shapes().empty())
{
if(x.dynamic())
{
os << "dynamic, ";
os << x.type_string() << ", ";
os << "{" << to_string_range(x.dyn_dims()) << "}";
}
else
{ {
os << x.type_string() << ", "; os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, "; os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}"; os << "{" << to_string_range(x.strides()) << "}";
} }
}
else else
{ {
os << "[" << to_string_range(x.sub_shapes()) << "]"; os << "[" << to_string_range(x.sub_shapes()) << "]";
...@@ -396,12 +436,25 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; } ...@@ -396,12 +436,25 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s) void migraphx_to_value(value& v, const shape& s)
{ {
value result; value result;
if(s.dynamic())
{
result["dynamic"] = migraphx::to_value(s.dynamic());
result["type"] = migraphx::to_value(s.type_string());
result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims());
result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims());
result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
else
{
result["type"] = migraphx::to_value(s.type_string()); result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens()); result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides()); result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
v = result; v = result;
} }
void migraphx_from_value(const value& v, shape& s) void migraphx_from_value(const value& v, shape& s)
{ {
auto t = v.at("type").get_string(); auto t = v.at("type").get_string();
...@@ -410,11 +463,29 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -410,11 +463,29 @@ void migraphx_from_value(const value& v, shape& s)
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))}; s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
} }
else else
{
auto dyn = v.at("dynamic").get_bool();
if(dyn)
{
auto mins = v.at("min_dyn_dims").to_vector<std::size_t>();
auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>();
auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>();
assert(mins.size() == maxes.size() == opts.size());
auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims{num_dims};
for(int i = 0; i < mins.size(); ++i)
{
dyn_dims.at(i) = {mins[i], maxes[i], opts[i]};
}
s = shape{shape::parse_type(t), dyn_dims};
}
else
{ {
s = shape{shape::parse_type(t), s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(), v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()}; v.at("strides").to_vector<std::size_t>()};
} }
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment