"src/vscode:/vscode.git/clone" did not exist on "603adbe60a62cd52568e5603b73a32cc580115a4"
Commit dcd3d04b authored by Paul's avatar Paul
Browse files

Horizontally fuse contiguous

parent a27dd28c
......@@ -578,6 +578,61 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct find_transpose_slice
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](match::name("slice")(match::output(match::name("transpose")))));
}
static std::vector<int64_t> find_common_perm(const std::vector<instruction_ref>& transposes)
{
std::map<std::vector<int64_t>, int64_t> count;
for(auto t:transposes)
{
auto perm = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
count[perm]++;
}
return std::max_element(
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }))->first;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
std::vector<instruction_ref> splits;
std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(splits), [&](instruction_ref out) {
return out->name() == "slice" and out->outputs().size() == 1 and out->outputs().front()->name() == "transpose";
});
if (splits.size() < 2)
return;
std::vector<instruction_ref> transposes;
std::transform(splits.begin(), splits.end(), std::back_inserter(transposes), [](auto split) {
return split->outputs().front();
});
auto perm = find_common_perm(transposes);
auto iperm = invert_permutation(perm);
auto pre = m.insert_instruction(std::next(ins), make_op("transpose", {{"permutation", perm}}), ins);
for(auto i:range(transposes.size()))
{
auto split = splits[i];
auto t = transposes[i];
auto op = any_cast<op::slice>(split->get_operator());
for(auto& axis:op.axes)
{
axis = iperm[axis];
}
auto new_ins = m.insert_instruction(t, op, pre);
if (t->get_operator() != pre->get_operator())
{
auto curr = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
new_ins = m.insert_instruction(t, make_op("transpose", {{"permutation", reorder_dims(iperm, curr)}}), new_ins);
}
m.replace_instruction(t, new_ins);
}
}
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 2; i++)
......@@ -593,6 +648,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_convert{},
find_nested_slice{},
find_nested_concat{},
find_transpose_slice{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment