Commit f92195d0 authored by Paul's avatar Paul
Browse files

Format

parent b3af63ac
......@@ -834,9 +834,7 @@ struct find_conv_dot_horiz_fusion
MIGRAPHX_PRED_MATCHER(horiz_dot_weights, instruction_ref ins)
{
auto pred = [&](auto name) {
return [=](auto i) {
return i->name() == name and i->inputs().back() == ins;
};
return [=](auto i) { return i->name() == name and i->inputs().back() == ins; };
};
return std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")) > 1;
}
......@@ -849,9 +847,10 @@ struct find_dot_horiz_fusion_weights
{
auto ins = r.result;
std::vector<instruction_ref> dots;
std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), [&](auto i) {
return i->name() == "dot" and i->inputs().back() == ins;
});
std::copy_if(ins->outputs().begin(),
ins->outputs().end(),
std::back_inserter(dots),
[&](auto i) { return i->name() == "dot" and i->inputs().back() == ins; });
std::sort(dots.begin(), dots.end(), by(std::less<>{}, [&](auto i) {
return std::distance(ins, i);
}));
......@@ -861,30 +860,32 @@ struct find_dot_horiz_fusion_weights
return contains(dots, input);
});
});
if (is_used)
if(is_used)
return;
std::vector<instruction_ref> args;
std::transform(
dots.begin(), dots.end(), std::back_inserter(args), [&](auto x) { return x->inputs().front(); });
std::transform(dots.begin(), dots.end(), std::back_inserter(args), [&](auto x) {
return x->inputs().front();
});
auto axis = args.front()->get_shape().lens().size() - 2;
auto last = dots.back();
auto weights = last->inputs().back();
auto concat =
m.insert_instruction(last, make_op("concat", {{"axis", axis}}), args);
auto concat = m.insert_instruction(last, make_op("concat", {{"axis", axis}}), args);
auto fused = m.insert_instruction(last, make_op("dot"), concat, weights);
int64_t offset = 0;
for(auto arg : dots)
{
int64_t len = arg->get_shape().lens()[axis];
auto slice = m.insert_instruction(last, make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), fused);
auto slice = m.insert_instruction(
last,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
m.replace_instruction(arg, slice);
offset += len;
}
}
};
struct find_div_const
{
auto matcher() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment