"vscode:/vscode.git/clone" did not exist on "b76a90438d83cc9bc59b955514c41d52d7c4a61b"
Commit 139eaeed authored by Paul's avatar Paul
Browse files

Fix permutation inversion

parent c87f5621
......@@ -54,10 +54,6 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t
return result;
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return reorder_dims(permutation, permutation);
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
......@@ -78,6 +74,11 @@ std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result;
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
......@@ -189,7 +190,7 @@ struct find_concat_transpose
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutation = invert_permutation(permutation);
op.axis = permutation[op.axis];
op.axis = ipermutation[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
......@@ -199,7 +200,8 @@ struct find_concat_transpose
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{permutation}, concat);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
};
......@@ -214,8 +216,9 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{}
// find_concat_transpose{}
match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
}
}
......
......@@ -284,4 +284,26 @@ TEST_CASE(concat_transpose1)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment