"src/vscode:/vscode.git/clone" did not exist on "4af5b21d56c24ee3ae86b28998917dba37a7cfc7"
Commit 010f59e5 authored by Paul's avatar Paul
Browse files

Improve inner broadcasts for more operators

parent ad73abbc
...@@ -562,6 +562,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in ...@@ -562,6 +562,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return nullopt; return nullopt;
} }
MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
{
return contains({"broadcast", "multibroadcast"}, ins->name());
}
template <class... Ms> template <class... Ms>
auto skip(Ms... ms) auto skip(Ms... ms)
{ {
...@@ -811,8 +816,7 @@ inline auto has_attribute(const std::string& name) ...@@ -811,8 +816,7 @@ inline auto has_attribute(const std::string& name)
template <class... Ms> template <class... Ms>
auto pointwise(Ms... ms) auto pointwise(Ms... ms)
{ {
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)), return match::has_attribute("pointwise")(ms...);
ms...);
} }
} // namespace match } // namespace match
......
...@@ -270,25 +270,31 @@ struct find_inner_broadcast ...@@ -270,25 +270,31 @@ struct find_inner_broadcast
auto matcher() const auto matcher() const
{ {
return pointwise( return pointwise(
match::nargs(2), match::all_of[match::inputs()](match::broadcast()));
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto broadcasts = ins->inputs();
auto y_ins = r.instructions["y"]; if (broadcasts.empty())
return;
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator()); if (std::any_of(broadcasts.begin(), broadcasts.end(), [&](auto i) {
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator()); return i->get_operator() != broadcasts.front()->get_operator();
}))
if(xbroadcast.axis != ybroadcast.axis) return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(inputs), [](auto i) {
return i->inputs().front();
});
if (std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape();
}))
return; return;
auto op = m.insert_instruction( auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front()); ins, ins->get_operator(), inputs);
m.replace_instruction(ins, xbroadcast, op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
}; };
...@@ -417,7 +423,7 @@ struct find_splits ...@@ -417,7 +423,7 @@ struct find_splits
auto matcher() const auto matcher() const
{ {
return match::any(match::any_of[match::outputs()](match::name("slice")( return match::any(match::any_of[match::outputs()](match::name("slice")(
match::any_of[match::outputs()](match::pointwise(), reduction())))); match::any_of[match::outputs()](match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
} }
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2) static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment