"...composable_kernel_rocm.git" did not exist on "09338e1f061dd2740a26b9461d1a6f0697efe886"
Unverified Commit 53c406bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Skip "find_split_transpose" if split has more than one outputs (#2126)

parent 14035176
...@@ -1446,10 +1446,13 @@ struct find_split_transpose ...@@ -1446,10 +1446,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
...@@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
...@@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm() ...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps) TEST_CASE(reorder_slice_ins_deps)
{ {
auto create_module = [] { auto create_module = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment