Unverified Commit 16a03b39 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Concat transpose bug (#638)



* fix a bug related to concat transpose.

* clang format

* use return instruction to replace the fake instruction
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8bf97a2f
...@@ -231,11 +231,25 @@ struct find_concat_transpose ...@@ -231,11 +231,25 @@ struct find_concat_transpose
void apply(program& p, const match::matcher_result& mr) const void apply(program& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto s = ins->inputs().front()->get_shape(); auto trans_inputs = ins->inputs();
auto s = trans_inputs.front()->get_shape();
assert(s.transposed()); assert(s.transposed());
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s); auto permutation = find_permutation(s);
// permutation should be the same for all inputs
if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
return (find_permutation(in->get_shape()) == permutation);
}))
{
return;
}
// axis could be a negative value
int64_t n_dim = static_cast<int64_t>(s.lens().size());
op.axis = (op.axis < 0) ? (op.axis + n_dim) : op.axis;
auto ipermutation = invert_permutation(permutation); auto ipermutation = invert_permutation(permutation);
op.axis = ipermutation[op.axis]; op.axis = ipermutation[op.axis];
......
...@@ -19,7 +19,7 @@ TEST_CASE(double_contig) ...@@ -19,7 +19,7 @@ TEST_CASE(double_contig)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_return({c2});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -36,7 +36,7 @@ TEST_CASE(double_transpose) ...@@ -36,7 +36,7 @@ TEST_CASE(double_transpose)
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_return({t2});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -55,7 +55,7 @@ TEST_CASE(double_transpose_contig) ...@@ -55,7 +55,7 @@ TEST_CASE(double_transpose_contig)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2); auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2); p.add_return({c2});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -71,7 +71,7 @@ TEST_CASE(single_transpose) ...@@ -71,7 +71,7 @@ TEST_CASE(single_transpose)
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1); p.add_return({t1});
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -123,7 +123,7 @@ TEST_CASE(reshape_transpose) ...@@ -123,7 +123,7 @@ TEST_CASE(reshape_transpose)
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1); auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t); auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct); auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2); p.add_return({r2});
EXPECT(p.get_output_shapes().back() == s); EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -138,7 +138,7 @@ TEST_CASE(transpose_contiguous) ...@@ -138,7 +138,7 @@ TEST_CASE(transpose_contiguous)
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1); p.add_return({c1});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -154,7 +154,7 @@ TEST_CASE(transpose_double_contiguous) ...@@ -154,7 +154,7 @@ TEST_CASE(transpose_double_contiguous)
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_return({c2});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -170,7 +170,7 @@ TEST_CASE(transpose_partial1) ...@@ -170,7 +170,7 @@ TEST_CASE(transpose_partial1)
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_return({t2});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -186,7 +186,7 @@ TEST_CASE(transpose_partial2) ...@@ -186,7 +186,7 @@ TEST_CASE(transpose_partial2)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3); p.add_return({t3});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -203,7 +203,7 @@ TEST_CASE(transpose_partial3) ...@@ -203,7 +203,7 @@ TEST_CASE(transpose_partial3)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3); auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4); p.add_return({t4});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -217,7 +217,7 @@ TEST_CASE(nop_transpose1) ...@@ -217,7 +217,7 @@ TEST_CASE(nop_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t); p.add_return({t});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -251,7 +251,7 @@ TEST_CASE(nop_transpose3) ...@@ -251,7 +251,7 @@ TEST_CASE(nop_transpose3)
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y); auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat); auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2); p.add_return({t2});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -269,7 +269,7 @@ TEST_CASE(concat_transpose1) ...@@ -269,7 +269,7 @@ TEST_CASE(concat_transpose1)
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y); auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t); p.add_return({t});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -289,9 +289,9 @@ TEST_CASE(concat_transpose2) ...@@ -289,9 +289,9 @@ TEST_CASE(concat_transpose2)
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x); auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y); auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{-1}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t); p.add_return({t});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -313,7 +313,7 @@ TEST_CASE(concat_transpose3) ...@@ -313,7 +313,7 @@ TEST_CASE(concat_transpose3)
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y); auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t); p.add_return({t});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -325,6 +325,25 @@ TEST_CASE(concat_transpose3) ...@@ -325,6 +325,25 @@ TEST_CASE(concat_transpose3)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
} }
TEST_CASE(concat_transpose4)
{
migraphx::program p;
auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = p.add_parameter("x", sx);
auto y = p.add_parameter("y", sy);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_return({t});
migraphx::program p1 = p;
run_pass(p);
EXPECT(p1 == p);
}
TEST_CASE(nested_concat) TEST_CASE(nested_concat)
{ {
migraphx::program p; migraphx::program p;
...@@ -334,7 +353,7 @@ TEST_CASE(nested_concat) ...@@ -334,7 +353,7 @@ TEST_CASE(nested_concat)
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y); auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2); auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_instruction(pass_op{}, concat3); p.add_return({concat3});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -354,7 +373,7 @@ TEST_CASE(nested_concat_partial) ...@@ -354,7 +373,7 @@ TEST_CASE(nested_concat_partial)
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y); auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l); auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_instruction(pass_op{}, concat3); p.add_return({concat3});
auto out_shape = p.get_output_shapes().back(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -383,7 +402,7 @@ TEST_CASE(double_slice1) ...@@ -383,7 +402,7 @@ TEST_CASE(double_slice1)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x); auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1); auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
p1.add_instruction(pass_op{}, slice2); p1.add_return({slice2});
} }
run_pass(p1); run_pass(p1);
...@@ -391,7 +410,7 @@ TEST_CASE(double_slice1) ...@@ -391,7 +410,7 @@ TEST_CASE(double_slice1)
{ {
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x); auto slice = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
p2.add_instruction(pass_op{}, slice); p2.add_return({slice});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -403,7 +422,7 @@ TEST_CASE(double_slice2) ...@@ -403,7 +422,7 @@ TEST_CASE(double_slice2)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x); auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1); auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2); p1.add_return({slice2});
} }
run_pass(p1); run_pass(p1);
...@@ -411,7 +430,7 @@ TEST_CASE(double_slice2) ...@@ -411,7 +430,7 @@ TEST_CASE(double_slice2)
{ {
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x); auto slice = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
p2.add_instruction(pass_op{}, slice); p2.add_return({slice});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -423,7 +442,7 @@ TEST_CASE(double_slice_multi_axes) ...@@ -423,7 +442,7 @@ TEST_CASE(double_slice_multi_axes)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x); auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1); auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2); p1.add_return({slice2});
} }
run_pass(p1); run_pass(p1);
...@@ -431,7 +450,7 @@ TEST_CASE(double_slice_multi_axes) ...@@ -431,7 +450,7 @@ TEST_CASE(double_slice_multi_axes)
{ {
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x); auto slice = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
p2.add_instruction(pass_op{}, slice); p2.add_return({slice});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment