Unverified Commit 16a03b39 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Concat transpose bug (#638)



* fix a bug related to concat transpose.

* clang format

* use return instruction to replace the fake instruction
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8bf97a2f
......@@ -231,11 +231,25 @@ struct find_concat_transpose
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
auto ins = mr.result;
auto trans_inputs = ins->inputs();
auto s = trans_inputs.front()->get_shape();
assert(s.transposed());
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
// permutation should be the same for all inputs
if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
return (find_permutation(in->get_shape()) == permutation);
}))
{
return;
}
// axis could be a negative value
int64_t n_dim = static_cast<int64_t>(s.lens().size());
op.axis = (op.axis < 0) ? (op.axis + n_dim) : op.axis;
auto ipermutation = invert_permutation(permutation);
op.axis = ipermutation[op.axis];
......
......@@ -19,7 +19,7 @@ TEST_CASE(double_contig)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
p.add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -36,7 +36,7 @@ TEST_CASE(double_transpose)
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2);
p.add_return({t2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -55,7 +55,7 @@ TEST_CASE(double_transpose_contig)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2);
p.add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -71,7 +71,7 @@ TEST_CASE(single_transpose)
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1);
p.add_return({t1});
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -123,7 +123,7 @@ TEST_CASE(reshape_transpose)
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
p.add_return({r2});
EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -138,7 +138,7 @@ TEST_CASE(transpose_contiguous)
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1);
p.add_return({c1});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -154,7 +154,7 @@ TEST_CASE(transpose_double_contiguous)
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
p.add_return({c2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -170,7 +170,7 @@ TEST_CASE(transpose_partial1)
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
p.add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -186,7 +186,7 @@ TEST_CASE(transpose_partial2)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
p.add_return({t3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -203,7 +203,7 @@ TEST_CASE(transpose_partial3)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
p.add_return({t4});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -217,7 +217,7 @@ TEST_CASE(nop_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
p.add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -251,7 +251,7 @@ TEST_CASE(nop_transpose3)
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
p.add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -269,7 +269,7 @@ TEST_CASE(concat_transpose1)
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
p.add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -289,9 +289,9 @@ TEST_CASE(concat_transpose2)
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto concat = p.add_instruction(migraphx::op::concat{-1}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
p.add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -313,7 +313,7 @@ TEST_CASE(concat_transpose3)
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
p.add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -325,6 +325,25 @@ TEST_CASE(concat_transpose3)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
TEST_CASE(concat_transpose4)
{
migraphx::program p;
auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = p.add_parameter("x", sx);
auto y = p.add_parameter("y", sy);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_return({t});
migraphx::program p1 = p;
run_pass(p);
EXPECT(p1 == p);
}
TEST_CASE(nested_concat)
{
migraphx::program p;
......@@ -334,7 +353,7 @@ TEST_CASE(nested_concat)
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_instruction(pass_op{}, concat3);
p.add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -354,7 +373,7 @@ TEST_CASE(nested_concat_partial)
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_instruction(pass_op{}, concat3);
p.add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -383,7 +402,7 @@ TEST_CASE(double_slice1)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
p1.add_instruction(pass_op{}, slice2);
p1.add_return({slice2});
}
run_pass(p1);
......@@ -391,7 +410,7 @@ TEST_CASE(double_slice1)
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
p2.add_instruction(pass_op{}, slice);
p2.add_return({slice});
}
EXPECT(p1 == p2);
}
......@@ -403,7 +422,7 @@ TEST_CASE(double_slice2)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2);
p1.add_return({slice2});
}
run_pass(p1);
......@@ -411,7 +430,7 @@ TEST_CASE(double_slice2)
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
p2.add_instruction(pass_op{}, slice);
p2.add_return({slice});
}
EXPECT(p1 == p2);
}
......@@ -423,7 +442,7 @@ TEST_CASE(double_slice_multi_axes)
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2);
p1.add_return({slice2});
}
run_pass(p1);
......@@ -431,7 +450,7 @@ TEST_CASE(double_slice_multi_axes)
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
p2.add_instruction(pass_op{}, slice);
p2.add_return({slice});
}
EXPECT(p1 == p2);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment