Commit 7f4c7809 authored by Paul's avatar Paul
Browse files

Formatting

parent 2e51006e
...@@ -21,7 +21,7 @@ struct matcher_context ...@@ -21,7 +21,7 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; } instruction_ref not_found() const { return last; }
template<class M> template <class M>
bool matched(M m, instruction_ref ins) bool matched(M m, instruction_ref ins)
{ {
return m.match(*this, ins) != this->not_found(); return m.match(*this, ins) != this->not_found();
...@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms) ...@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
} }
} }
template<class Op, bool Start, bool Matches> template <class Op, bool Start, bool Matches>
struct folder struct folder
{ {
template <class... Ts> template <class... Ts>
...@@ -257,7 +257,7 @@ struct folder ...@@ -257,7 +257,7 @@ struct folder
}); });
} }
template<class Selector> template <class Selector>
auto operator[](Selector select) const auto operator[](Selector select) const
{ {
return [=](auto... ms) { return [=](auto... ms) {
...@@ -284,7 +284,7 @@ const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{}; ...@@ -284,7 +284,7 @@ const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline auto inputs() inline auto inputs()
{ {
return [](auto ins, auto f) { return [](auto ins, auto f) {
for(auto&& x:ins->inputs()) for(auto&& x : ins->inputs())
f(x); f(x);
}; };
} }
...@@ -292,7 +292,7 @@ inline auto inputs() ...@@ -292,7 +292,7 @@ inline auto inputs()
inline auto outputs() inline auto outputs()
{ {
return [](auto ins, auto f) { return [](auto ins, auto f) {
for(auto&& x:ins->outputs()) for(auto&& x : ins->outputs())
f(x); f(x);
}; };
} }
...@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins) ...@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins) MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
{ {
if (ins->inputs().empty()) if(ins->inputs().empty())
return false; return false;
auto s = ins->inputs().front()->get_shape(); auto s = ins->inputs().front()->get_shape();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return std::all_of(
return x->get_shape() == s; ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
});
} }
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
...@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) ...@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found(); return ctx.not_found();
} }
template<class... Ms> template <class... Ms>
auto skip_output(Ms... ms) auto skip_output(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
...@@ -345,10 +344,10 @@ auto skip_output(Ms... ms) ...@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
{ {
auto next = ins->outputs().front(); auto next = ins->outputs().front();
if (ctx.matched(m, next)) if(ctx.matched(m, next))
{ {
auto skipped_next = self(next); auto skipped_next = self(next);
if (skipped_next != ctx.not_found()) if(skipped_next != ctx.not_found())
return skipped_next; return skipped_next;
} }
return next; return next;
...@@ -366,8 +365,9 @@ inline auto name(std::string s) ...@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
[ =, names = std::move(names) ](instruction_ref ins) { return names.count(ins->name()) > 0; }); return names.count(ins->name()) > 0;
});
} }
inline auto nargs(std::size_t n) inline auto nargs(std::size_t n)
......
...@@ -25,10 +25,7 @@ const auto& reshaper_names() ...@@ -25,10 +25,7 @@ const auto& reshaper_names()
return names; return names;
} }
bool is_reshaper(instruction_ref ins) bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
{
return contains(reshaper_names(), ins->name());
}
instruction_ref find_transpose_input(instruction_ref ins) instruction_ref find_transpose_input(instruction_ref ins)
{ {
...@@ -90,7 +87,8 @@ struct find_reshaper ...@@ -90,7 +87,8 @@ struct find_reshaper
{ {
auto matcher() const auto matcher() const
{ {
return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names()))); return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
...@@ -139,7 +137,8 @@ struct find_transpose ...@@ -139,7 +137,8 @@ struct find_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")))); return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
...@@ -172,7 +171,8 @@ struct find_concat_transpose ...@@ -172,7 +171,8 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::same_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
...@@ -187,10 +187,9 @@ struct find_concat_transpose ...@@ -187,10 +187,9 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
ins->inputs().end(), return p.insert_instruction(ins, op::transpose{permutation}, i);
std::back_inserter(inputs), });
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat); auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
...@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const ...@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions // Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
match::find_matches(p, ins, match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment