Commit 7f4c7809 authored by Paul's avatar Paul
Browse files

Formatting

parent 2e51006e
......@@ -21,7 +21,7 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template<class M>
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
......@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
}
}
template<class Op, bool Start, bool Matches>
template <class Op, bool Start, bool Matches>
struct folder
{
template <class... Ts>
......@@ -257,7 +257,7 @@ struct folder
});
}
template<class Selector>
template <class Selector>
auto operator[](Selector select) const
{
return [=](auto... ms) {
......@@ -284,7 +284,7 @@ const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline auto inputs()
{
return [](auto ins, auto f) {
for(auto&& x:ins->inputs())
for(auto&& x : ins->inputs())
f(x);
};
}
......@@ -292,7 +292,7 @@ inline auto inputs()
inline auto outputs()
{
return [](auto ins, auto f) {
for(auto&& x:ins->outputs())
for(auto&& x : ins->outputs())
f(x);
};
}
......@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
{
if (ins->inputs().empty())
if(ins->inputs().empty())
return false;
auto s = ins->inputs().front()->get_shape();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) {
return x->get_shape() == s;
});
return std::all_of(
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
......@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
template<class... Ms>
template <class... Ms>
auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
......@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if (ctx.matched(m, next))
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if (skipped_next != ctx.not_found())
if(skipped_next != ctx.not_found())
return skipped_next;
}
return next;
......@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher(
[ =, names = std::move(names) ](instruction_ref ins) { return names.count(ins->name()) > 0; });
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
inline auto nargs(std::size_t n)
......
......@@ -25,10 +25,7 @@ const auto& reshaper_names()
return names;
}
bool is_reshaper(instruction_ref ins)
{
return contains(reshaper_names(), ins->name());
}
bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
instruction_ref find_transpose_input(instruction_ref ins)
{
......@@ -90,7 +87,8 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names())));
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, match::matcher_result mr) const
......@@ -139,7 +137,8 @@ struct find_transpose
{
auto matcher() const
{
return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose"))));
return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, match::matcher_result mr) const
......@@ -172,7 +171,8 @@ struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape()));
return match::name("concat")(match::same_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, match::matcher_result mr) const
......@@ -187,10 +187,9 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t);
......@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p, ins,
find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment