Commit c0e18e78 authored by charlie's avatar charlie
Browse files

Dynamic shape handling in shape object

parent 764273e4
...@@ -48,6 +48,11 @@ struct check_shapes ...@@ -48,6 +48,11 @@ struct check_shapes
return end - begin; return end - begin;
} }
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template <class... Ts> template <class... Ts>
const check_shapes& has(Ts... ns) const const check_shapes& has(Ts... ns) const
{ {
......
...@@ -59,6 +59,15 @@ struct shape ...@@ -59,6 +59,15 @@ struct shape
{ {
}; };
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
bool is_fixed() const { return min == max; };
bool has_optimal() const { return opt != 0; };
};
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t); static std::string name(type_t t);
...@@ -69,6 +78,8 @@ struct shape ...@@ -69,6 +78,8 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
shape(type_t t, std::vector<dynamic_dimension> dims);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{ {
...@@ -93,6 +104,8 @@ struct shape ...@@ -93,6 +104,8 @@ struct shape
std::size_t bytes() const; std::size_t bytes() const;
std::size_t type_size() const; std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index /// Map multiple indices to space index
...@@ -115,17 +128,24 @@ struct shape ...@@ -115,17 +128,24 @@ struct shape
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed with no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
/// order /// order
bool transposed() const; bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const; bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and /// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed. /// not transposed.
bool standard() const; bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor) /// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const; bool scalar() const;
/// Return true if the shape is dynamic
bool dynamic() const;
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
...@@ -225,6 +245,7 @@ struct shape ...@@ -225,6 +245,7 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
/// size of the data buffer
std::size_t element_space() const; std::size_t element_space() const;
private: private:
......
...@@ -13,11 +13,18 @@ shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation) ...@@ -13,11 +13,18 @@ shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
} }
/*!
* Inverts the permutation using the less_than operator
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation) std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{ {
return sort_permutation(permutation, std::less<>{}); return sort_permutation(permutation, std::less<>{});
} }
/*!
* Computes a permutation for the lengths based on decesending stride order.
* Compares the lengths if the strides are the same.
*/
std::vector<int64_t> find_permutation(const shape& s) std::vector<int64_t> find_permutation(const shape& s)
{ {
std::vector<std::int64_t> result(s.lens().size()); std::vector<std::int64_t> result(s.lens().size());
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
...@@ -45,11 +46,20 @@ struct shape_impl ...@@ -45,11 +46,20 @@ struct shape_impl
} }
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dynamic(true), m_dyn_dims(std::move(dims))
{
}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens = {}; std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {}; std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
bool m_dynamic = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides() void calculate_strides()
{ {
...@@ -66,6 +76,11 @@ struct shape_impl ...@@ -66,6 +76,11 @@ struct shape_impl
std::size_t element_space() const std::size_t element_space() const
{ {
if(m_dynamic)
{
MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -80,6 +95,11 @@ struct shape_impl ...@@ -80,6 +95,11 @@ struct shape_impl
std::size_t elements() const std::size_t elements() const
{ {
if(m_dynamic)
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -137,6 +157,11 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -137,6 +157,11 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t,
...@@ -150,9 +175,13 @@ shape shape::from_permutation(type_t t, ...@@ -150,9 +175,13 @@ shape shape::from_permutation(type_t t,
} }
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
...@@ -176,6 +205,9 @@ std::size_t shape::type_size() const ...@@ -176,6 +205,9 @@ std::size_t shape::type_size() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n; return n;
} }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
...@@ -235,13 +267,23 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -235,13 +267,23 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
}); });
} }
bool shape::dynamic() const { return (impl->m_dynamic); }
bool shape::packed() const bool shape::packed() const
{ {
if(this->dynamic())
{
return false;
}
return this->sub_shapes().empty() and this->elements() == this->element_space(); return this->sub_shapes().empty() and this->elements() == this->element_space();
} }
bool shape::transposed() const bool shape::transposed() const
{ {
if(this->dynamic())
{
return false;
}
if(this->broadcasted()) if(this->broadcasted())
{ {
// TODO: Use a filter_iterator instead // TODO: Use a filter_iterator instead
...@@ -261,6 +303,10 @@ bool shape::transposed() const ...@@ -261,6 +303,10 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(), return std::accumulate(this->strides().begin(),
this->strides().end(), this->strides().end(),
...@@ -270,6 +316,10 @@ bool shape::broadcasted() const ...@@ -270,6 +316,10 @@ bool shape::broadcasted() const
bool shape::scalar() const bool shape::scalar() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return this->sub_shapes().empty() and return this->sub_shapes().empty() and
......
...@@ -958,7 +958,7 @@ TEST_CASE(multibroadcast) ...@@ -958,7 +958,7 @@ TEST_CASE(multibroadcast)
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type};
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
......
...@@ -42,6 +42,39 @@ TEST_CASE(test_shape_standard) ...@@ -42,6 +42,39 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_dynamic_fixed)
{
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 2, 0});
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 2, 0});
dims.emplace_back(migraphx::shape::dynamic_dimension{3, 3, 0});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 3);
EXPECT(s.dyn_dims().at(0).is_fixed());
EXPECT(not s.dyn_dims().at(0).has_optimal());
}
TEST_CASE(test_shape_dynamic_not_fixed)
{
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 5, 2});
dims.emplace_back(migraphx::shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 2);
EXPECT(not s.dyn_dims().at(0).is_fixed());
EXPECT(s.dyn_dims().at(0).has_optimal());
}
TEST_CASE(test_shape_packed) TEST_CASE(test_shape_packed)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment