"src/vscode:/vscode.git/clone" did not exist on "4e3ca5861a59c7e82aad43d49ec3a622188d96d8"
Commit f92195d0 authored by Paul's avatar Paul
Browse files

Format

parent b3af63ac
...@@ -834,9 +834,7 @@ struct find_conv_dot_horiz_fusion ...@@ -834,9 +834,7 @@ struct find_conv_dot_horiz_fusion
MIGRAPHX_PRED_MATCHER(horiz_dot_weights, instruction_ref ins) MIGRAPHX_PRED_MATCHER(horiz_dot_weights, instruction_ref ins)
{ {
auto pred = [&](auto name) { auto pred = [&](auto name) {
return [=](auto i) { return [=](auto i) { return i->name() == name and i->inputs().back() == ins; };
return i->name() == name and i->inputs().back() == ins;
};
}; };
return std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")) > 1; return std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")) > 1;
} }
...@@ -849,9 +847,10 @@ struct find_dot_horiz_fusion_weights ...@@ -849,9 +847,10 @@ struct find_dot_horiz_fusion_weights
{ {
auto ins = r.result; auto ins = r.result;
std::vector<instruction_ref> dots; std::vector<instruction_ref> dots;
std::copy_if(ins->outputs().begin(), ins->outputs().end(), std::back_inserter(dots), [&](auto i) { std::copy_if(ins->outputs().begin(),
return i->name() == "dot" and i->inputs().back() == ins; ins->outputs().end(),
}); std::back_inserter(dots),
[&](auto i) { return i->name() == "dot" and i->inputs().back() == ins; });
std::sort(dots.begin(), dots.end(), by(std::less<>{}, [&](auto i) { std::sort(dots.begin(), dots.end(), by(std::less<>{}, [&](auto i) {
return std::distance(ins, i); return std::distance(ins, i);
})); }));
...@@ -861,30 +860,32 @@ struct find_dot_horiz_fusion_weights ...@@ -861,30 +860,32 @@ struct find_dot_horiz_fusion_weights
return contains(dots, input); return contains(dots, input);
}); });
}); });
if (is_used) if(is_used)
return; return;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
std::transform( std::transform(dots.begin(), dots.end(), std::back_inserter(args), [&](auto x) {
dots.begin(), dots.end(), std::back_inserter(args), [&](auto x) { return x->inputs().front(); }); return x->inputs().front();
});
auto axis = args.front()->get_shape().lens().size() - 2; auto axis = args.front()->get_shape().lens().size() - 2;
auto last = dots.back(); auto last = dots.back();
auto weights = last->inputs().back(); auto weights = last->inputs().back();
auto concat = auto concat = m.insert_instruction(last, make_op("concat", {{"axis", axis}}), args);
m.insert_instruction(last, make_op("concat", {{"axis", axis}}), args);
auto fused = m.insert_instruction(last, make_op("dot"), concat, weights); auto fused = m.insert_instruction(last, make_op("dot"), concat, weights);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : dots) for(auto arg : dots)
{ {
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
auto slice = m.insert_instruction(last, make_op("slice", auto slice = m.insert_instruction(
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), fused); last,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
m.replace_instruction(arg, slice); m.replace_instruction(arg, slice);
offset += len; offset += len;
} }
} }
}; };
struct find_div_const struct find_div_const
{ {
auto matcher() const auto matcher() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment