Commit 9288a1fe authored by Paul's avatar Paul
Browse files

Adjust permutation axis

parent 1e4e9c57
...@@ -699,17 +699,19 @@ struct find_transpose_slice ...@@ -699,17 +699,19 @@ struct find_transpose_slice
ins->get_shape().lens().begin() + axis, ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; })) [](auto x) { return x == 1; }))
return; return;
// Compute axis before transpose to use for unsqueeeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze // Make unsqeeze
auto unsqueeze = m.insert_instruction( auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs()); ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
// Make transpose // Make transpose
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) { std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i >= axis) if(i > preaxis)
return i + 1; return i + 1;
return i; return i;
}); });
perm.insert(perm.begin(), axis); perm.insert(perm.begin(), preaxis+1);
auto transpose = auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze); m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze // Slice and sqeeze
......
...@@ -1233,14 +1233,16 @@ TEST_CASE(transpose_slice_non_packed_axis) ...@@ -1233,14 +1233,16 @@ TEST_CASE(transpose_slice_non_packed_axis)
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice); auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
m1.add_return({sqrt}); m1.add_return({sqrt});
} }
auto output_shapes = m1.get_output_shapes();
run_pass(m1); run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction( auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze); migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
...@@ -1270,14 +1272,16 @@ TEST_CASE(transpose_slice_non_packed_multi_axis) ...@@ -1270,14 +1272,16 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
transpose); transpose);
m1.add_return({slice1, transpose2, slice3}); m1.add_return({slice1, transpose2, slice3});
} }
auto output_shapes = m1.get_output_shapes();
run_pass(m1); run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction( auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze); migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1); auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment