Commit caf02990 authored by Paul's avatar Paul
Browse files

Formatting

parent a505910b
...@@ -75,14 +75,12 @@ bool is_no_transpose(const std::vector<int64_t>& dims) ...@@ -75,14 +75,12 @@ bool is_no_transpose(const std::vector<int64_t>& dims)
dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end(); dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
} }
template<class Vector, class Op> template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op) std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{ {
std::vector<std::int64_t> result(data.size()); std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0); std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return op(data[x], data[y]);
});
return result; return result;
} }
...@@ -159,28 +157,26 @@ void simplify_reshapes::apply(program& p) const ...@@ -159,28 +157,26 @@ void simplify_reshapes::apply(program& p) const
} }
else if(ins->name() == "concat") else if(ins->name() == "concat")
{ {
if (ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto s = ins->inputs().front()->get_shape(); auto s = ins->inputs().front()->get_shape();
if (none_of(ins->inputs(), [&](auto i) { if(none_of(ins->inputs(), [&](auto i) { return i->get_shape().transposed(); }) or
return i->get_shape().transposed(); none_of(ins->inputs(), [&](auto i) { return i->get_shape() == s; }))
}) or none_of(ins->inputs(), [&](auto i) {
return i->get_shape() == s;
}))
continue; continue;
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s); auto permutation = find_permutation(s);
auto ipermutaion = invert_permutation(permutation); auto ipermutaion = invert_permutation(permutation);
op.axis = ipermutaion[op.axis]; op.axis = ipermutaion[op.axis];
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { std::transform(
return p.insert_instruction(ins, op::transpose{permutation}, i); ins->inputs().begin(),
}); ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat); auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment