Commit 8bf8e161 authored by charlie's avatar charlie
Browse files

Shape class changes to handle dynamic

* More throw errors for functions that don't make sense for dynamic shape
* Print output changes
* Serialization changes
parent 4680518a
......@@ -64,8 +64,21 @@ struct shape
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
bool is_fixed() const { return min == max; };
bool has_optimal() const { return opt != 0; };
bool is_fixed() const { return min == max; }
bool has_optimal() const { return opt != 0; }
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return !(x == y);
}
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
};
static const std::vector<type_t>& types();
......@@ -100,11 +113,28 @@ struct shape
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor. Multiply the lengths.
*/
std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data.
* Includes subshapes.
*/
std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
const std::vector<std::size_t>& min_dyn_dims() const;
const std::vector<std::size_t>& max_dyn_dims() const;
const std::vector<std::size_t>& opt_dyn_dims() const;
/// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const;
......
......@@ -184,6 +184,10 @@ std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: bytes() called on dynamic shape");
}
if(this->sub_shapes().empty())
{
std::size_t n = 0;
......@@ -210,18 +214,30 @@ const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return im
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(std::size_t i) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(this->lens().size() == this->strides().size());
if(this->standard())
return i;
......@@ -338,6 +354,10 @@ shape shape::normalize_standard() const
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
assert(l.size() == this->lens().size());
auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm);
......@@ -345,6 +365,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
return this->with_lens(this->type(), l);
}
......@@ -361,18 +385,34 @@ std::string shape::type_string() const { return name(this->type()); }
bool operator==(const shape& x, const shape& y)
{
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
if(x.dynamic() and y.dynamic())
{
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
x.sub_shapes() == y.sub_shapes());
}
return x.impl == y.impl or
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
if(x.sub_shapes().empty())
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
if(x.dynamic())
{
os << "dynamic, ";
os << x.type_string() << ", ";
os << "{" << to_string_range(x.dyn_dims()) << "}";
}
else
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
}
}
else
{
......@@ -396,12 +436,25 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
v = result;
if(s.dynamic())
{
result["dynamic"] = migraphx::to_value(s.dynamic());
result["type"] = migraphx::to_value(s.type_string());
result["min_dyn_dims"] = migraphx::to_value(s.min_dyn_dims());
result["max_dyn_dims"] = migraphx::to_value(s.max_dyn_dims());
result["opt_dyn_dims"] = migraphx::to_value(s.opt_dyn_dims());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
else
{
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
}
v = result;
}
void migraphx_from_value(const value& v, shape& s)
{
auto t = v.at("type").get_string();
......@@ -411,9 +464,27 @@ void migraphx_from_value(const value& v, shape& s)
}
else
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
auto dyn = v.at("dynamic").get_bool();
if(dyn)
{
auto mins = v.at("min_dyn_dims").to_vector<std::size_t>();
auto maxes = v.at("max_dyn_dims").to_vector<std::size_t>();
auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>();
assert(mins.size() == maxes.size() == opts.size());
auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims{num_dims};
for(int i = 0; i < mins.size(); ++i)
{
dyn_dims.at(i) = {mins[i], maxes[i], opts[i]};
}
s = shape{shape::parse_type(t), dyn_dims};
}
else
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment