Commit de4c1b44 authored by charlie's avatar charlie
Browse files

Add dyn_data struct to avoid ambiguous constructor

parent 0a8342b8
...@@ -87,6 +87,13 @@ struct shape ...@@ -87,6 +87,13 @@ struct shape
} }
}; };
// Avoid ambiguous constructor
struct dyn_data
{
type_t t;
std::vector<dynamic_dimension> dims;
};
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t); static std::string name(type_t t);
...@@ -97,7 +104,7 @@ struct shape ...@@ -97,7 +104,7 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
shape(type_t t, std::vector<dynamic_dimension> dims); explicit shape(dyn_data data);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
......
...@@ -47,8 +47,8 @@ struct shape_impl ...@@ -47,8 +47,8 @@ struct shape_impl
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims) explicit shape_impl(shape::dyn_data data)
: m_type(t), m_dynamic(true), m_dyn_dims(std::move(dims)) : m_type(data.t), m_dynamic(true), m_dyn_dims(data.dims)
{ {
} }
...@@ -157,10 +157,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -157,10 +157,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims) shape::shape(dyn_data data) : impl(std::make_shared<shape_impl>(std::move(data))) {}
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
...@@ -509,7 +506,7 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -509,7 +506,7 @@ void migraphx_from_value(const value& v, shape& s)
{ {
dyn_dims.emplace_back(mins[i], maxes[i], opts[i]); dyn_dims.emplace_back(mins[i], maxes[i], opts[i]);
} }
s = shape{shape::parse_type(t), dyn_dims}; s = shape{migraphx::shape::dyn_data{shape::parse_type(t), dyn_dims}};
} }
else else
{ {
......
...@@ -43,11 +43,7 @@ TEST_CASE(test_shape_standard) ...@@ -43,11 +43,7 @@ TEST_CASE(test_shape_standard)
TEST_CASE(test_shape_dynamic_fixed) TEST_CASE(test_shape_dynamic_fixed)
{ {
std::vector<migraphx::shape::dynamic_dimension> dims = {}; migraphx::shape s{{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}}};
dims.emplace_back(2, 2, 0);
dims.emplace_back(2, 2, 0);
dims.emplace_back(3, 3, 0);
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -63,7 +59,7 @@ TEST_CASE(test_shape_dynamic_not_fixed) ...@@ -63,7 +59,7 @@ TEST_CASE(test_shape_dynamic_not_fixed)
std::vector<migraphx::shape::dynamic_dimension> dims = {}; std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2); dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0); dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -84,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -84,10 +80,10 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(a == c); EXPECT(a == c);
EXPECT(a != d); EXPECT(a != d);
migraphx::shape s0{migraphx::shape::float_type, {a, d}}; migraphx::shape s0{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s1 = s0; migraphx::shape s1 = s0;
migraphx::shape s2{migraphx::shape::float_type, {a, d}}; migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::float_type, {a, d}}};
migraphx::shape s3{migraphx::shape::int32_type, {a}}; migraphx::shape s3{{migraphx::shape::int32_type, {a}}};
EXPECT(s0 == s1); EXPECT(s0 == s1);
EXPECT(s0 == s2); EXPECT(s0 == s2);
EXPECT(s0 != s3); EXPECT(s0 != s3);
...@@ -98,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors) ...@@ -98,7 +94,7 @@ TEST_CASE(test_shape_dynamic_errors)
std::vector<migraphx::shape::dynamic_dimension> dims = {}; std::vector<migraphx::shape::dynamic_dimension> dims = {};
dims.emplace_back(2, 5, 2); dims.emplace_back(2, 5, 2);
dims.emplace_back(2, 8, 0); dims.emplace_back(2, 8, 0);
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::dyn_data{migraphx::shape::float_type, dims}};
EXPECT(test::throws([&] { s.element_space(); })); EXPECT(test::throws([&] { s.element_space(); }));
EXPECT(test::throws([&] { s.elements(); })); EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.bytes(); })); EXPECT(test::throws([&] { s.bytes(); }));
...@@ -113,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize) ...@@ -113,12 +109,12 @@ TEST_CASE(test_shape_dynamic_serialize)
std::vector<migraphx::shape::dynamic_dimension> dims1 = {}; std::vector<migraphx::shape::dynamic_dimension> dims1 = {};
dims1.emplace_back(2, 5, 2); dims1.emplace_back(2, 5, 2);
dims1.emplace_back(2, 8, 0); dims1.emplace_back(2, 8, 0);
migraphx::shape s1{migraphx::shape::float_type, dims1}; migraphx::shape s1{migraphx::shape::dyn_data{migraphx::shape::float_type, dims1}};
auto v1 = migraphx::to_value(s1); auto v1 = migraphx::to_value(s1);
std::vector<migraphx::shape::dynamic_dimension> dims2 = {}; std::vector<migraphx::shape::dynamic_dimension> dims2 = {};
dims2.emplace_back(2, 5, 2); dims2.emplace_back(2, 5, 2);
migraphx::shape s2{migraphx::shape::uint64_type, dims2}; migraphx::shape s2{migraphx::shape::dyn_data{migraphx::shape::uint64_type, dims2}};
auto v2 = migraphx::to_value(s2); auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2); EXPECT(v1 != v2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment