Commit a6ea39ea authored by Paul's avatar Paul
Browse files

Formatting

parent c46b5480
...@@ -469,7 +469,7 @@ inline auto name(std::unordered_set<std::string> names) ...@@ -469,7 +469,7 @@ inline auto name(std::unordered_set<std::string> names)
}); });
} }
template<class... Ts> template <class... Ts>
inline auto name(std::string s, Ts... xs) inline auto name(std::string s, Ts... xs)
{ {
return name(std::unordered_set<std::string>{s, xs...}); return name(std::unordered_set<std::string>{s, xs...});
......
...@@ -140,7 +140,8 @@ struct find_inner_broadcast ...@@ -140,7 +140,8 @@ struct find_inner_broadcast
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul", "add")(match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); return match::name("mul", "add")(
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
...@@ -152,10 +153,11 @@ struct find_inner_broadcast ...@@ -152,10 +153,11 @@ struct find_inner_broadcast
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator()); auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator()); auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if (xbroadcast.axis != ybroadcast.axis) if(xbroadcast.axis != ybroadcast.axis)
return; return;
auto op = p.insert_instruction(ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front()); auto op = p.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op); p.replace_instruction(ins, xbroadcast, op);
} }
}; };
......
...@@ -325,8 +325,8 @@ struct find_mul_add ...@@ -325,8 +325,8 @@ struct find_mul_add
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::add")( return match::name("gpu::add")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b"))); match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
...@@ -350,7 +350,8 @@ struct find_mul_add_relu ...@@ -350,7 +350,8 @@ struct find_mul_add_relu
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::relu")(match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add"))); return match::name("gpu::relu")(
match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment