Commit ac0224a9 authored by charlie's avatar charlie
Browse files

Use std::initializer_list in constructor

Reverts the dyn_data struct change
Should get around the ambiguous braced initialization list error
parent 2e27a823
......@@ -89,13 +89,6 @@ struct shape
}
};
// Avoid ambiguous constructor
struct dyn_data
{
type_t t;
std::vector<dynamic_dimension> dims;
};
static const std::vector<type_t>& types();
static std::string name(type_t t);
......@@ -106,7 +99,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
explicit shape(dyn_data data);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
typedef std::vector<dynamic_dimension> dynamic_dimensions;
shape(type_t t, dynamic_dimensions dims);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
......@@ -130,7 +128,7 @@ struct shape
const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor. Multiply the lengths.
* Return the number of elements in the tensor.
*/
std::size_t elements() const;
......
......@@ -45,19 +45,18 @@ struct shape_impl
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
explicit shape_impl(shape::dyn_data data)
: m_type(data.t), m_dynamic(true), m_dyn_dims(std::move(data.dims))
shape_impl(shape::type_t t, shape::dynamic_dimensions dims)
: m_type(t), m_dyn_dims(std::move(dims))
{
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {};
bool m_standard = false;
bool m_dynamic = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
......@@ -76,7 +75,7 @@ struct shape_impl
std::size_t element_space() const
{
if(m_dynamic)
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: element_space() called on dynamic shape");
}
......@@ -95,7 +94,7 @@ struct shape_impl
std::size_t elements() const
{
if(m_dynamic)
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
......@@ -155,9 +154,17 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape::shape(type_t t, dynamic_dimensions dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(dyn_data data) : impl(std::make_shared<shape_impl>(std::move(data))) {}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
......@@ -199,6 +206,7 @@ std::size_t shape::bytes() const
[&](auto x, auto y) { return x + y.bytes(); });
}
}
std::size_t shape::type_size() const
{
std::size_t n = 0;
......@@ -378,7 +386,7 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return impl->m_dynamic; }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
......@@ -502,12 +510,12 @@ void migraphx_from_value(const value& v, shape& s)
auto opts = v.at("opt_dyn_dims").to_vector<std::size_t>();
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
auto num_dims = mins.size();
std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
shape::dynamic_dimensions dyn_dims(num_dims);
for(int i = 0; i < num_dims; ++i)
{
dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
}
s = shape{migraphx::shape::dyn_data{shape::parse_type(t), dyn_dims}};
s = shape{shape::parse_type(t), dyn_dims};
}
else
{
......
......@@ -43,7 +43,7 @@ TEST_CASE(test_shape_standard)
TEST_CASE(test_shape_dynamic_fixed)
{
migraphx::shape s{{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -56,10 +56,10 @@ TEST_CASE(test_shape_dynamic_fixed)
TEST_CASE(test_shape_dynamic_not_fixed)
{
std::vector<migraphx::shape::dynamic_dimension> dims = {};
migraphx::shape::dynamic_dimensions dims = {};
dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -80,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(a == c);
EXPECT(a != d);
migraphx::shape s0{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s0{migraphx::shape::float_type, {a, d}};
migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s3{{migraphx::shape::int32_type, {a}}};
migraphx::shape s2{migraphx::shape::float_type, {a, d}};
migraphx::shape s3{migraphx::shape::int32_type, {a}};
EXPECT(s0 == s1);
EXPECT(s0 == s2);
EXPECT(s0 != s3);
......@@ -94,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors)
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(test::throws([&] { s.element_space(); }));
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); }));
......@@ -109,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize)
std::vector<migraphx::shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(2, 5, 2);
dims1.emplace_back(2, 8, 0);
migraphx::shape s1{migraphx::shape::dyn_data{migraphx::shape::float_type, dims1}};
migraphx::shape s1{migraphx::shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(2, 5, 2);
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::uint64_type, dims2}};
migraphx::shape s2{migraphx::shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment