Commit de4c1b44 authored by charlie's avatar charlie
Browse files

Add dyn_data struct to avoid ambiguous constructor

parent 0a8342b8
......@@ -87,6 +87,13 @@ struct shape
}
};
// Avoid ambiguous constructor
struct dyn_data
{
type_t t;
std::vector<dynamic_dimension> dims;
};
static const std::vector<type_t>& types();
static std::string name(type_t t);
......@@ -97,7 +104,7 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
shape(type_t t, std::vector<dynamic_dimension> dims);
explicit shape(dyn_data data);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
......
......@@ -47,8 +47,8 @@ struct shape_impl
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dynamic(true), m_dyn_dims(std::move(dims))
explicit shape_impl(shape::dyn_data data)
: m_type(data.t), m_dynamic(true), m_dyn_dims(data.dims)
{
}
......@@ -157,10 +157,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(dyn_data data) : impl(std::make_shared<shape_impl>(std::move(data))) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
......@@ -509,7 +506,7 @@ void migraphx_from_value(const value& v, shape& s)
{
dyn_dims.emplace_back(mins[i], maxes[i], opts[i]);
}
s = shape{shape::parse_type(t), dyn_dims};
s = shape{migraphx::shape::dyn_data{shape::parse_type(t), dyn_dims}};
}
else
{
......
......@@ -43,11 +43,7 @@ TEST_CASE(test_shape_standard)
TEST_CASE(test_shape_dynamic_fixed)
{
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 2, 0);
dims.emplace_back(2, 2, 0);
dims.emplace_back(3, 3, 0);
migraphx::shape s{migraphx::shape::float_type, dims};
migraphx::shape s{{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -63,7 +59,7 @@ TEST_CASE(test_shape_dynamic_not_fixed)
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::float_type, dims};
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -84,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(a == c);
EXPECT(a != d);
migraphx::shape s0{migraphx::shape::float_type, {a, d}};
migraphx::shape s0{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::float_type, {a, d}};
migraphx::shape s3{migraphx::shape::int32_type, {a}};
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s3{{migraphx::shape::int32_type, {a}}};
EXPECT(s0 == s1);
EXPECT(s0 == s2);
EXPECT(s0 != s3);
......@@ -98,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors)
std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::float_type, dims};
migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
EXPECT(test::throws([&] { s.element_space(); }));
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); }));
......@@ -113,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize)
std::vector<migraphx::shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(2, 5, 2);
dims1.emplace_back(2, 8, 0);
migraphx::shape s1{migraphx::shape::float_type, dims1};
migraphx::shape s1{migraphx::shape::dyn_data{migraphx::shape::float_type, dims1}};
auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(2, 5, 2);
migraphx::shape s2{migraphx::shape::uint64_type, dims2};
migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::uint64_type, dims2}};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment