Commit a505910b authored by Paul's avatar Paul
Browse files

Simplify concat/transpose

parent 692274e5
......@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
......@@ -59,6 +60,11 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t
return result;
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return reorder_dims(permutation, permutation);
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
if(dims.empty())
......@@ -69,6 +75,22 @@ bool is_no_transpose(const std::vector<int64_t>& dims)
dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
}
template<class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) {
return op(data[x], data[y]);
});
return result;
}
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
......@@ -135,6 +157,31 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
}
}
else if(ins->name() == "concat")
{
if (ins->inputs().empty())
continue;
auto s = ins->inputs().front()->get_shape();
if (none_of(ins->inputs(), [&](auto i) {
return i->get_shape().transposed();
}) or none_of(ins->inputs(), [&](auto i) {
return i->get_shape() == s;
}))
continue;
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutaion = invert_permutation(permutation);
op.axis = ipermutaion[op.axis];
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment