Unverified Commit 1c0b2a4a authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dyn shape update (#1199)

Initial sketch for changes to shape to handle dynamic dimensions
parent bd503d89
......@@ -71,6 +71,11 @@ struct check_shapes
return end - begin;
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template <class... Ts>
const check_shapes& has(Ts... ns) const
{
......
......@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result;
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
......@@ -82,6 +82,23 @@ struct shape
{
};
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
template <class Self, class F>
static auto reflect(Self& self, F f);
bool is_fixed() const;
bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
};
static const std::vector<type_t>& types();
static std::string name(type_t t);
......@@ -92,6 +109,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::vector<dynamic_dimension> dims);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{
......@@ -112,10 +135,44 @@ struct shape
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor.
*/
std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> min_lens() const;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> max_lens() const;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> opt_lens() const;
/// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index
......@@ -136,19 +193,27 @@ struct shape
std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;
/// Return true if the shape is dynamic
bool dynamic() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......@@ -252,6 +317,11 @@ struct shape
const std::vector<shape>& sub_shapes() const;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std::size_t element_space() const;
private:
......
......@@ -26,6 +26,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
......@@ -65,13 +66,21 @@ struct shape_impl
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dyn_dims(std::move(dims))
{
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {};
bool m_standard = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides()
{
m_strides.clear();
......@@ -87,6 +96,12 @@ struct shape_impl
std::size_t element_space() const
{
if(not m_dyn_dims.empty())
{
auto maxes = max_lens();
return std::accumulate(maxes.begin(), maxes.end(), std::size_t{1}, std::multiplies<>());
}
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
......@@ -101,6 +116,11 @@ struct shape_impl
std::size_t elements() const
{
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
......@@ -108,6 +128,35 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.min; });
return ret;
}
std::vector<std::size_t> max_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.max; });
return ret;
}
std::vector<std::size_t> opt_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.opt; });
return ret;
}
// Does the shape skip over elements?
bool skips() const
{
......@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
......@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
}
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
{
if(this->sub_shapes().empty())
......@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
[&](auto x, auto y) { return x + y.bytes(); });
}
}
std::size_t shape::type_size() const
{
std::size_t n = 0;
......@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(std::size_t i) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(this->lens().size() == this->strides().size());
if(this->standard())
return i;
......@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const
{
if(this->dynamic())
{
return false;
}
return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space();
}
bool shape::transposed() const
{
if(this->dynamic())
{
return false;
}
if(this->broadcasted())
{
// TODO: Use a filter_iterator instead
......@@ -292,6 +379,10 @@ bool shape::transposed() const
bool shape::broadcasted() const
{
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size());
return std::any_of(
this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
......@@ -299,6 +390,10 @@ bool shape::broadcasted() const
bool shape::scalar() const
{
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false
return this->sub_shapes().empty() and
......@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
assert(l.size() == this->lens().size());
auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm);
......@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
return this->with_lens(this->type(), l);
}
......@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
{
return this->dynamic() ? impl->min_lens() : this->lens();
}
std::vector<std::size_t> shape::max_lens() const
{
return this->dynamic() ? impl->max_lens() : this->lens();
}
std::vector<std::size_t> shape::opt_lens() const
{
return this->dynamic() ? impl->opt_lens() : this->lens();
}
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
bool operator==(const shape& x, const shape& y)
{
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
if(x.dynamic() and y.dynamic())
{
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
x.sub_shapes() == y.sub_shapes());
}
return x.impl == y.impl or
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
if(x.sub_shapes().empty())
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
if(x.dynamic())
{
os << "dynamic, ";
os << x.type_string() << ", ";
os << "{" << to_string_range(x.dyn_dims()) << "}";
}
else
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
}
}
else
{
......@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
v = result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
v = result;
}
void migraphx_from_value(const value& v, shape& s)
{
auto t = v.at("type").get_string();
......@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s)
}
else
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
if(v.at("dynamic_dimensions").empty())
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
});
s = shape{shape::parse_type(t), dyn_dims};
}
}
}
......
......@@ -981,7 +981,8 @@ TEST_CASE(multibroadcast)
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
std::vector<std::size_t> empt = {};
migraphx::shape input{migraphx::shape::float_type, empt};
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
......
......@@ -38,7 +38,6 @@ TEST_CASE(test_shape_default)
EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0);
}
TEST_CASE(test_shape_assign)
{
migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
......@@ -65,6 +64,118 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens());
}
TEST_CASE(test_shape_dynamic_fixed)
{
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 3);
EXPECT(s.dyn_dims().at(0).is_fixed());
EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0});
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_not_fixed)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 2);
EXPECT(not s.dyn_dims().at(0).is_fixed());
EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0});
EXPECT(s.bytes() == 5 * 8 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto b = a;
auto c = shape::dynamic_dimension{2, 5, 2};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
EXPECT(a == c);
EXPECT(a != d);
migraphx::shape s0{shape::float_type, {a, d}};
migraphx::shape s1 = s0;
migraphx::shape s2{shape::float_type, {a, d}};
migraphx::shape s3{shape::int32_type, {a}};
EXPECT(s0 == s1);
EXPECT(s0 == s2);
EXPECT(s0 != s3);
std::stringstream ss0;
std::stringstream ss1;
std::stringstream ss3;
ss0 << s0;
ss1 << s1;
ss3 << s3;
EXPECT(ss0.str() == ss1.str());
EXPECT(ss0.str() != ss3.str());
}
TEST_CASE(test_shape_dynamic_errors)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); }));
EXPECT(test::throws([&] { s.index(1); }));
EXPECT(test::throws([&] { s.index(std::vector<std::size_t>{0, 1}); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); }));
}
TEST_CASE(test_shape_dynamic_serialize)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2});
dims1.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2});
migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
auto s3 = migraphx::from_value<shape>(v1);
EXPECT(s3 == s1);
auto s4 = migraphx::from_value<shape>(v2);
EXPECT(s4 == s2);
EXPECT(s3 != s4);
}
TEST_CASE(test_shape_packed)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment