Commit 6d34b90f authored by Paul's avatar Paul
Browse files

Format

parent 00df057a
...@@ -191,8 +191,7 @@ struct find_dot_add ...@@ -191,8 +191,7 @@ struct find_dot_add
{ {
return match::name("dot")(match::either_arg(0, 1)( return match::name("dot")(match::either_arg(0, 1)(
match::name("add")( match::name("add")(
match::either_arg(0, 1)( match::either_arg(0, 1)(match::any().bind("x"),
match::any().bind("x"),
match::any_of(match::is_constant()).bind("b")), match::any_of(match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())), match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()), match::used_once()),
...@@ -210,7 +209,7 @@ struct find_dot_add ...@@ -210,7 +209,7 @@ struct find_dot_add
const bool flipped = a_ins == ins->inputs().back(); const bool flipped = a_ins == ins->inputs().back();
auto insert_dot = [&](auto x, auto y) { auto insert_dot = [&](auto x, auto y) {
if (flipped) if(flipped)
return m.insert_instruction(ins, make_op("dot"), y, x); return m.insert_instruction(ins, make_op("dot"), y, x);
else else
return m.insert_instruction(ins, make_op("dot"), x, y); return m.insert_instruction(ins, make_op("dot"), x, y);
...@@ -283,24 +282,24 @@ struct find_inner_broadcast ...@@ -283,24 +282,24 @@ struct find_inner_broadcast
{ {
auto matcher() const auto matcher() const
{ {
return pointwise( return pointwise(match::all_of[match::inputs()](
match::all_of[match::inputs()](match::broadcast_shape(), match::name("broadcast", "multibroadcast"))); match::broadcast_shape(), match::name("broadcast", "multibroadcast")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
if (inputs.empty()) if(inputs.empty())
return; return;
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if (contains({"broadcast", "multibroadcast"}, i->name())) if(contains({"broadcast", "multibroadcast"}, i->name()))
return i->inputs().front(); return i->inputs().front();
else else
return i; return i;
}); });
if (not std::all_of(inputs.begin(), inputs.end(), [&](auto& x) { if(not std::all_of(inputs.begin(), inputs.end(), [&](auto& x) {
return x->get_shape() == inputs.front()->get_shape(); return x->get_shape() == inputs.front()->get_shape();
})) }))
return; return;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment