Commit 39bbf87c authored by Paul's avatar Paul
Browse files

Format

parent dcd3d04b
...@@ -582,51 +582,60 @@ struct find_transpose_slice ...@@ -582,51 +582,60 @@ struct find_transpose_slice
{ {
auto matcher() const auto matcher() const
{ {
return match::any(match::any_of[match::outputs()](match::name("slice")(match::output(match::name("transpose"))))); return match::any(match::any_of[match::outputs()](
match::name("slice")(match::output(match::name("transpose")))));
} }
static std::vector<int64_t> find_common_perm(const std::vector<instruction_ref>& transposes) static std::vector<int64_t> find_common_perm(const std::vector<instruction_ref>& transposes)
{ {
std::map<std::vector<int64_t>, int64_t> count; std::map<std::vector<int64_t>, int64_t> count;
for(auto t:transposes) for(auto t : transposes)
{ {
auto perm = t->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
count[perm]++; count[perm]++;
} }
return std::max_element( return std::max_element(
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }))->first; count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }))
->first;
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
std::vector<instruction_ref> splits; std::vector<instruction_ref> splits;
std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(splits), [&](instruction_ref out) { std::copy_if(ins->outputs().begin(),
return out->name() == "slice" and out->outputs().size() == 1 and out->outputs().front()->name() == "transpose"; ins->outputs().end(),
}); std::back_inserter(splits),
if (splits.size() < 2) [&](instruction_ref out) {
return out->name() == "slice" and out->outputs().size() == 1 and
out->outputs().front()->name() == "transpose";
});
if(splits.size() < 2)
return; return;
std::vector<instruction_ref> transposes; std::vector<instruction_ref> transposes;
std::transform(splits.begin(), splits.end(), std::back_inserter(transposes), [](auto split) { std::transform(splits.begin(),
return split->outputs().front(); splits.end(),
}); std::back_inserter(transposes),
auto perm = find_common_perm(transposes); [](auto split) { return split->outputs().front(); });
auto perm = find_common_perm(transposes);
auto iperm = invert_permutation(perm); auto iperm = invert_permutation(perm);
auto pre = m.insert_instruction(std::next(ins), make_op("transpose", {{"permutation", perm}}), ins); auto pre = m.insert_instruction(
for(auto i:range(transposes.size())) std::next(ins), make_op("transpose", {{"permutation", perm}}), ins);
for(auto i : range(transposes.size()))
{ {
auto split = splits[i]; auto split = splits[i];
auto t = transposes[i]; auto t = transposes[i];
auto op = any_cast<op::slice>(split->get_operator()); auto op = any_cast<op::slice>(split->get_operator());
for(auto& axis:op.axes) for(auto& axis : op.axes)
{ {
axis = iperm[axis]; axis = iperm[axis];
} }
auto new_ins = m.insert_instruction(t, op, pre); auto new_ins = m.insert_instruction(t, op, pre);
if (t->get_operator() != pre->get_operator()) if(t->get_operator() != pre->get_operator())
{ {
auto curr = t->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto curr = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
new_ins = m.insert_instruction(t, make_op("transpose", {{"permutation", reorder_dims(iperm, curr)}}), new_ins); new_ins = m.insert_instruction(
t, make_op("transpose", {{"permutation", reorder_dims(iperm, curr)}}), new_ins);
} }
m.replace_instruction(t, new_ins); m.replace_instruction(t, new_ins);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment