Commit 139eaeed authored by Paul's avatar Paul
Browse files

Fix permutation inversion

parent c87f5621
...@@ -54,10 +54,6 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t ...@@ -54,10 +54,6 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t
return result; return result;
} }
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return reorder_dims(permutation, permutation);
}
bool is_no_transpose(const std::vector<int64_t>& dims) bool is_no_transpose(const std::vector<int64_t>& dims)
{ {
...@@ -78,6 +74,11 @@ std::vector<int64_t> sort_permutation(const Vector& data, Op op) ...@@ -78,6 +74,11 @@ std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result; return result;
} }
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s) std::vector<int64_t> find_permutation(const shape& s)
{ {
return sort_permutation(s.strides(), std::greater<>{}); return sort_permutation(s.strides(), std::greater<>{});
...@@ -189,7 +190,7 @@ struct find_concat_transpose ...@@ -189,7 +190,7 @@ struct find_concat_transpose
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s); auto permutation = find_permutation(s);
auto ipermutation = invert_permutation(permutation); auto ipermutation = invert_permutation(permutation);
op.axis = permutation[op.axis]; op.axis = ipermutation[op.axis];
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
...@@ -199,7 +200,8 @@ struct find_concat_transpose ...@@ -199,7 +200,8 @@ struct find_concat_transpose
return p.insert_instruction(ins, op::transpose{permutation}, i); return p.insert_instruction(ins, op::transpose{permutation}, i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{permutation}, concat); auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
} }
}; };
...@@ -214,8 +216,9 @@ void simplify_reshapes::apply(program& p) const ...@@ -214,8 +216,9 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions // Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{} match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{},
// find_concat_transpose{} find_transpose{},
find_concat_transpose{}
); );
} }
} }
......
...@@ -284,4 +284,26 @@ TEST_CASE(concat_transpose1) ...@@ -284,4 +284,26 @@ TEST_CASE(concat_transpose1)
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3); EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
} }
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment