Unverified Commit 64b306ab authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Ensure same layout is used for miopen convolution (#2102)

parent ea97ce52
...@@ -70,13 +70,19 @@ struct check_shapes ...@@ -70,13 +70,19 @@ struct check_shapes
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
: begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
{
check_dynamic();
}
void check_dynamic() const void check_dynamic() const
{ {
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); })) if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
...@@ -228,6 +234,16 @@ struct check_shapes ...@@ -228,6 +234,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same layout.
*/
const check_shapes& same_layout() const
{
if(not this->same([](const shape& s) { return find_permutation(s); }))
MIGRAPHX_THROW(prefix() + "Layouts do not match");
return *this;
}
/*! /*!
* Check all shapes are standard. * Check all shapes are standard.
*/ */
......
...@@ -84,8 +84,10 @@ struct miopen_convolution ...@@ -84,8 +84,10 @@ struct miopen_convolution
{ {
check_shapes{inputs, op}.has(4); check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts( check_shapes{conv_inputs, *this}
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); .max_ndims(5)
.packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}})
.same_layout();
return migraphx::compute_shape<Op>(op, conv_inputs); return migraphx::compute_shape<Op>(op, conv_inputs);
} }
......
...@@ -31,24 +31,39 @@ ...@@ -31,24 +31,39 @@
using migraphx::shape; using migraphx::shape;
bool create_shapes(bool dynamic_allowed) void create_shapes(bool dynamic_allowed)
{ {
try shape a{shape::int64_type, {3}};
{ shape b{shape::float_type, {{3, 6}, {4, 4}}};
shape a{shape::int64_type, {3}}; migraphx::check_shapes{{a, b}, "", dynamic_allowed}.has(2);
shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true;
}
catch(...)
{
return false;
}
} }
TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); } TEST_CASE(allow_dynamic_shape)
{
EXPECT(not test::throws([] { create_shapes(true); }));
}
TEST_CASE(fail_dynamic_shape)
{
EXPECT(test::throws([] { create_shapes(false); }));
}
TEST_CASE(fail_dynamic_shape) { EXPECT(not create_shapes(false)); } TEST_CASE(same_layout_fail)
{
EXPECT(test::throws([] {
shape a{shape::float_type, {2, 3}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}
TEST_CASE(same_layout_pass)
{
EXPECT(not test::throws([] {
shape a{shape::float_type, {2, 3}, {1, 2}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment