Unverified Commit f0e530f0 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix bug in unsqueeze shape and improve test coverage (#495)



* Fix bug in unsqueeze shape and improve test coverage

* Remove else statement

* Revert unsqueeze changes

* Update tests

* Formatting

* Update tests

* Fix handling of scalars

* Add another scalar test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 11bc6104
...@@ -35,7 +35,12 @@ struct unsqueeze ...@@ -35,7 +35,12 @@ struct unsqueeze
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
if(input_shape.scalar()) if(input_shape.scalar())
return shape{type, old_lens}; {
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
......
...@@ -499,17 +499,62 @@ TEST_CASE(test_argmin) ...@@ -499,17 +499,62 @@ TEST_CASE(test_argmin)
TEST_CASE(test_squeeze) TEST_CASE(test_squeeze)
{ {
{ migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; expect_shape(s2, migraphx::op::squeeze{{3}}, s1);
expect_shape(s2, migraphx::op::squeeze{{-2}}, s1); }
}
{ TEST_CASE(test_squeeze_negative_axis)
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; {
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
expect_shape(s2, migraphx::op::unsqueeze{{-2}}, s1); migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
} expect_shape(s2, migraphx::op::squeeze{{-2}}, s1);
}
TEST_CASE(test_squeeze_wrong_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
throws_shape(migraphx::op::squeeze{{0}}, s1);
}
TEST_CASE(test_squeeze_all)
{
migraphx::shape s1{migraphx::shape::float_type, {1}};
migraphx::shape s2{migraphx::shape::float_type};
expect_shape(s2, migraphx::op::squeeze{{0}}, s1);
}
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
expect_shape(s2, migraphx::op::unsqueeze{{0}}, s1);
}
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::op::unsqueeze{{-2}}, s);
}
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::op::unsqueeze{{-2}}, s);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::op::unsqueeze{{2}}, s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::op::unsqueeze{{-2}}, s1);
} }
template <class T> template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment