Commit 17abf67e authored by charlie's avatar charlie
Browse files

Fix stuff and add tests

parent 4ed53bc6
...@@ -79,8 +79,7 @@ struct unsqueeze ...@@ -79,8 +79,7 @@ struct unsqueeze
std::size_t k = 0; std::size_t k = 0;
for(auto i : range(new_ndim)) for(auto i : range(new_ndim))
{ {
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin(); if(std::find(axes.begin(), axes.end(), i) != axes.end())
if(axis_idx < axes.size())
{ {
dyn_dims.push_back({1, 1, 0}); dyn_dims.push_back({1, 1, 0});
} }
......
...@@ -1999,6 +1999,19 @@ TEST_CASE(test_unsqueeze) ...@@ -1999,6 +1999,19 @@ TEST_CASE(test_unsqueeze)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step) TEST_CASE(test_unsqueeze_step)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
...@@ -2030,13 +2043,27 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis) ...@@ -2030,13 +2043,27 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(test_unsqueeze_negative_axis1)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis2)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 3, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis3)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 5, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-3}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar) TEST_CASE(test_unsqueeze_scalar)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}}; migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
......
...@@ -7362,6 +7362,25 @@ TEST_CASE(unsqueeze_test) ...@@ -7362,6 +7362,25 @@ TEST_CASE(unsqueeze_test)
} }
} }
TEST_CASE(unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::ref::target{});
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 3, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(where_test) TEST_CASE(where_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment