Commit bd70cd8d authored by Paul's avatar Paul
Browse files

Merge branch 'fuse-horiz-contiguous2' into bert-opt2

parents 68727a64 9962a56d
......@@ -74,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
......
......@@ -275,7 +275,7 @@ struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
}
void apply(module& m, const match::matcher_result& mr) const
......
......@@ -1551,7 +1551,7 @@ TEST_CASE(test_unsqueeze_step_non_divisable)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_non_zero)
TEST_CASE(test_unsqueeze_step_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {0}}}), s1);
......@@ -1563,6 +1563,12 @@ TEST_CASE(test_unsqueeze_step_at_end)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {3}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_mismatch_step_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
......@@ -1659,6 +1665,13 @@ TEST_CASE(test_unsqueeze_multiple_axes_4)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {5, 4, 2}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_step)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 10}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 2, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}, {"steps", {2}}}), s1);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
......@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT(m1 == m2);
}
TEST_CASE(transpose_unsqueeze_concat)
{
migraphx::module m1;
{
auto l0 = m1.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l1 = m1.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = m1.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 3;
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
m1.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args);
}
// TODO: This could be simplified to a single transpose after concat
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice)
{
migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment