Commit cb801f60 authored by Paul's avatar Paul
Browse files

Format

parent 010f59e5
...@@ -267,33 +267,29 @@ struct find_double_add_lit_broadcast ...@@ -267,33 +267,29 @@ struct find_double_add_lit_broadcast
struct find_inner_broadcast struct find_inner_broadcast
{ {
auto matcher() const auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
{
return pointwise(
match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto broadcasts = ins->inputs(); auto broadcasts = ins->inputs();
if (broadcasts.empty()) if(broadcasts.empty())
return; return;
if (std::any_of(broadcasts.begin(), broadcasts.end(), [&](auto i) { if(std::any_of(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return i->get_operator() != broadcasts.front()->get_operator(); return i->get_operator() != broadcasts.front()->get_operator();
})) }))
return; return;
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(inputs), [](auto i) { std::transform(broadcasts.begin(),
return i->inputs().front(); broadcasts.end(),
}); std::back_inserter(inputs),
if (std::any_of(inputs.begin(), inputs.end(), [&](auto i) { [](auto i) { return i->inputs().front(); });
return i->get_shape() != inputs.front()->get_shape(); if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
})) return i->get_shape() != inputs.front()->get_shape();
}))
return; return;
auto op = m.insert_instruction( auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
ins, ins->get_operator(), inputs);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
}; };
...@@ -422,8 +418,9 @@ struct find_splits ...@@ -422,8 +418,9 @@ struct find_splits
{ {
auto matcher() const auto matcher() const
{ {
return match::any(match::any_of[match::outputs()](match::name("slice")( return match::any(
match::any_of[match::outputs()](match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction())))); match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
} }
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2) static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment