Commit b4a9a9c5 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add tests, rename variables

parent 8d7948aa
......@@ -642,8 +642,8 @@ struct find_broadcast_transpose
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose
......@@ -651,7 +651,8 @@ struct find_broadcast_transpose
{
// find common shape
auto in_lens = input->get_shape().lens();
int lens_diff = ins_lens.size() - in_lens.size();
int lens_diff = transpose_lens.size() - in_lens.size();
// insert unsqueeze if input lens < transpose lens
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
......@@ -659,11 +660,12 @@ struct find_broadcast_transpose
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
input = m.insert_instruction(bcast_ins, ins->get_operator(), input);
// apply transpose before the multibroadcast
input = m.insert_instruction(bcast_ins, transpose->get_operator(), input);
}
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
bcast_ins, make_op("multibroadcast", {{"out_lens", transpose_lens}}), input);
m.replace_instruction(transpose, new_mbcast);
}
};
......
......@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_no_common_axis)
{
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {5, 10}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 5, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1)
{
migraphx::module m;
......
......@@ -67,6 +67,54 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return m;
}
TEST_CASE(broadcast_transpose)
{
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1);
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1);
m2.add_return({mb});
}
EXPECT(m1 == m2);
}
TEST_CASE(broadcast_transpose_opt)
{
// extra transpose from transformation will be optimized out
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1);
m2.add_return({mb});
}
EXPECT(m1 == m2);
}
TEST_CASE(broadcast_transpose_scalar)
{
migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment