Commit ac0224a9 authored by charlie's avatar charlie
Browse files

Use std::initializer_list in constructor

Reverts the dyn_data struct change
Should get around the ambiguous braced initialization list error
parent 2e27a823
...@@ -89,13 +89,6 @@ struct shape ...@@ -89,13 +89,6 @@ struct shape
} }
}; };
// Avoid ambiguous constructor
struct dyn_data
{
type_t t;
std::vector<dynamic_dimension> dims;
};
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t); static std::string name(type_t t);
...@@ -106,7 +99,12 @@ struct shape ...@@ -106,7 +99,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
explicit shape(dyn_data data); // Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
typedef std::vector<dynamic_dimension> dynamic_dimensions;
shape(type_t t, dynamic_dimensions dims);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
...@@ -130,7 +128,7 @@ struct shape ...@@ -130,7 +128,7 @@ struct shape
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*! /*!
* Return the number of elements in the tensor. Multiply the lengths. * Return the number of elements in the tensor.
*/ */
std::size_t elements() const; std::size_t elements() const;
......
...@@ -45,19 +45,18 @@ struct shape_impl ...@@ -45,19 +45,18 @@ struct shape_impl
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(shape::type_t t, shape::dynamic_dimensions dims)
: m_type(t), m_dyn_dims(std::move(dims))
explicit shape_impl(shape::dyn_data data)
: m_type(data.t), m_dynamic(true), m_dyn_dims(std::move(data.dims))
{ {
} }
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens = {}; std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {}; std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
bool m_dynamic = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {}; std::vector<shape::dynamic_dimension> m_dyn_dims = {};
...@@ -76,7 +75,7 @@ struct shape_impl ...@@ -76,7 +75,7 @@ struct shape_impl
std::size_t element_space() const std::size_t element_space() const
{ {
if(m_dynamic) if(not m_dyn_dims.empty())
{ {
MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape"); MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape");
} }
...@@ -95,7 +94,7 @@ struct shape_impl ...@@ -95,7 +94,7 @@ struct shape_impl
std::size_t elements() const std::size_t elements() const
{ {
if(m_dynamic) if(not m_dyn_dims.empty())
{ {
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape"); MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
} }
...@@ -155,9 +154,17 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -155,9 +154,17 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{ {
} }
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape::shape(type_t t, dynamic_dimensions dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(dyn_data data) : impl(std::make_shared<shape_impl>(std::move(data))) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
...@@ -199,6 +206,7 @@ std::size_t shape::bytes() const ...@@ -199,6 +206,7 @@ std::size_t shape::bytes() const
[&](auto x, auto y) { return x + y.bytes(); }); [&](auto x, auto y) { return x + y.bytes(); });
} }
} }
std::size_t shape::type_size() const std::size_t shape::type_size() const
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -378,7 +386,7 @@ std::size_t shape::element_space() const { return impl->element_space(); } ...@@ -378,7 +386,7 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return impl->m_dynamic; } bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; } const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
...@@ -502,12 +510,12 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -502,12 +510,12 @@ void migraphx_from_value(const value& v, shape& s)
auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>(); auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>();
assert(mins.size() == maxes.size() and maxes.size() == opts.size()); assert(mins.size() == maxes.size() and maxes.size() == opts.size());
auto num_dims = mins.size(); auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims(num_dims); shape::dynamic_dimensions dyn_dims(num_dims);
for(int i = 0; i < num_dims; ++i) for(int i = 0; i < num_dims; ++i)
{ {
dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]}; dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
} }
s = shape{migraphx::shape::dyn_data{shape::parse_type(t), dyn_dims}}; s = shape{shape::parse_type(t), dyn_dims};
} }
else else
{ {
......
...@@ -43,7 +43,7 @@ TEST_CASE(test_shape_standard) ...@@ -43,7 +43,7 @@ TEST_CASE(test_shape_standard)
TEST_CASE(test_shape_dynamic_fixed) TEST_CASE(test_shape_dynamic_fixed)
{ {
migraphx::shape s{{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -56,10 +56,10 @@ TEST_CASE(test_shape_dynamic_fixed) ...@@ -56,10 +56,10 @@ TEST_CASE(test_shape_dynamic_fixed)
TEST_CASE(test_shape_dynamic_not_fixed) TEST_CASE(test_shape_dynamic_not_fixed)
{ {
std::vector<migraphx::shape::dynamic_dimension> dims = {}; migraphx::shape::dynamic_dimensions dims = {};
dims.emplace_back(2, 5, 2); dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0); dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}}; migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -80,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -80,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(a == c); EXPECT(a == c);
EXPECT(a != d); EXPECT(a != d);
migraphx::shape s0{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}}; migraphx::shape s0{migraphx::shape::float_type, {a, d}};
migraphx::shape s1 = s0; migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}}; migraphx::shape s2{migraphx::shape::float_type, {a, d}};
migraphx::shape s3{{migraphx::shape::int32_type, {a}}}; migraphx::shape s3{migraphx::shape::int32_type, {a}};
EXPECT(s0 == s1); EXPECT(s0 == s1);
EXPECT(s0 == s2); EXPECT(s0 == s2);
EXPECT(s0 != s3); EXPECT(s0 != s3);
...@@ -94,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors) ...@@ -94,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors)
std::vector<migraphx::shape::dynamic_dimension> dims = {}; std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2); dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0); dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}}; migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(test::throws([&] { s.element_space(); })); EXPECT(test::throws([&] { s.element_space(); }));
EXPECT(test::throws([&] { s.elements(); })); EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); })); EXPECT(test::throws([&] { s.bytes(); }));
...@@ -109,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize) ...@@ -109,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize)
std::vector<migraphx::shape::dynamic_dimension> dims1 = {}; std::vector<migraphx::shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(2, 5, 2); dims1.emplace_back(2, 5, 2);
dims1.emplace_back(2, 8, 0); dims1.emplace_back(2, 8, 0);
migraphx::shape s1{migraphx::shape::dyn_data{migraphx::shape::float_type, dims1}}; migraphx::shape s1{migraphx::shape::float_type, dims1};
auto v1 = migraphx::to_value(s1); auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {}; std::vector<migraphx::shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(2, 5, 2); dims2.emplace_back(2, 5, 2);
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::uint64_type, dims2}}; migraphx::shape s2{migraphx::shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2); auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2); EXPECT(v1 != v2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment