Commit 00dae07f authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add generic case for broadcast transpose

parent 434a06cf
...@@ -521,6 +521,24 @@ struct find_inner_broadcast ...@@ -521,6 +521,24 @@ struct find_inner_broadcast
}) < (lens.size() - 1); }) < (lens.size() - 1);
})) }))
return; return;
auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{
for(auto j = 0; j < broadcasts.size(); j++)
{
if(broadcasts[j]->get_shape().strides()[i] == 0)
common_axis[i]++;
}
}
// if no common broadcast axis, transformation is not useful
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) {
return num_common > 1;
}) == common_axis.end())
return;
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), std::transform(broadcasts.begin(),
broadcasts.end(), broadcasts.end(),
...@@ -543,6 +561,17 @@ struct find_inner_broadcast ...@@ -543,6 +561,17 @@ struct find_inner_broadcast
return 3; return 3;
})); }));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs); auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
}; };
......
...@@ -641,10 +641,20 @@ struct find_broadcast_transpose ...@@ -641,10 +641,20 @@ struct find_broadcast_transpose
auto ins_lens = ins->get_shape().lens(); auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation // scalar transformation does not need extra transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
return; {
// find common shape
auto in_lens = input->get_shape().lens();
int lens_diff = ins_lens.size() - in_lens.size();
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
input = m.insert_instruction(bcast_ins, ins->get_operator(), input);
}
auto new_mbcast = m.insert_instruction( auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input); bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast); m.replace_instruction(ins, new_mbcast);
......
...@@ -62,4 +62,38 @@ TEST_CASE(broadcast_transpose_inner_broadcast) ...@@ -62,4 +62,38 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(broadcast_transpose_inner_broadcast_generic)
{
// first optimizes broadcast+transpose to unsqueeze+transpose+broadcast,
// then finds inner broadcast to become mul+broadcast
migraphx::module m1;
{
auto l1 = m1.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m1.add_parameter("y", {migraphx::shape::float_type, {5}});
auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), l1);
auto mb2 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 10, 5}}}), l2);
auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), mb1, t1);
m1.add_return({mul});
}
run_pass(m1);
migraphx::module m2;
{
auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2);
auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2);
auto mb3 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
m2.add_return({mb3});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment