Commit 9288a1fe authored by Paul's avatar Paul
Browse files

Adjust permutation axis

parent 1e4e9c57
......@@ -699,17 +699,19 @@ struct find_transpose_slice
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return;
// Compute axis before transpose to use for unsqueeeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs());
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
// Make transpose
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i >= axis)
if(i > preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), axis);
perm.insert(perm.begin(), preaxis+1);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze
......
......@@ -1233,14 +1233,16 @@ TEST_CASE(transpose_slice_non_packed_axis)
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
m1.add_return({sqrt});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
......@@ -1270,14 +1272,16 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
transpose);
m1.add_return({slice1, transpose2, slice3});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment