Commit b3af63ac authored by Paul's avatar Paul
Browse files

Horizontally fuse gemms that share the same weights

parent 1704bb04
...@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion ...@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion
} }
}; };
MIGRAPHX_PRED_MATCHER(horiz_dot_weights, instruction_ref ins)
{
auto pred = [&](auto name) {
return [=](auto i) {
return i->name() == name and i->inputs().back() == ins;
};
};
return std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")) > 1;
}
struct find_dot_horiz_fusion_weights
{
auto matcher() const { return horiz_dot_weights(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
std::vector<instruction_ref> dots;
std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), [&](auto i) {
return i->name() == "dot" and i->inputs().back() == ins;
});
std::sort(dots.begin(), dots.end(), by(std::less<>{}, [&](auto i) {
return std::distance(ins, i);
}));
// Check if used between operators
const bool is_used = std::any_of(dots.front(), dots.back(), [&](const auto& i) {
return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto input) {
return contains(dots, input);
});
});
if (is_used)
return;
std::vector<instruction_ref> args;
std::transform(
dots.begin(), dots.end(), std::back_inserter(args), [&](auto x) { return x->inputs().front(); });
auto axis = args.front()->get_shape().lens().size() - 2;
auto last = dots.back();
auto weights = last->inputs().back();
auto concat =
m.insert_instruction(last, make_op("concat", {{"axis", axis}}), args);
auto fused = m.insert_instruction(last, make_op("dot"), concat, weights);
int64_t offset = 0;
for(auto arg : dots)
{
int64_t len = arg->get_shape().lens()[axis];
auto slice = m.insert_instruction(last, make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), fused);
m.replace_instruction(arg, slice);
offset += len;
}
}
};
struct find_div_const struct find_div_const
{ {
auto matcher() const auto matcher() const
...@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const
find_add_lit_broadcast{}, find_add_lit_broadcast{},
find_add_convs{}, find_add_convs{},
find_conv_dot_horiz_fusion{}, find_conv_dot_horiz_fusion{},
find_dot_horiz_fusion_weights{},
find_mul_conv{}, find_mul_conv{},
find_mul_slice_conv{}, find_mul_slice_conv{},
find_mul_add{}, find_mul_add{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment