Unverified Commit 53c406bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Skip "find_split_transpose" if split has more than one outputs (#2126)

parent 14035176
...@@ -1446,10 +1446,13 @@ struct find_split_transpose ...@@ -1446,10 +1446,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm() ...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps) TEST_CASE(reorder_slice_ins_deps)
{ {
auto create_module = [] { auto create_module = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment