Commit 7f4c7809 authored by Paul's avatar Paul
Browse files

Formatting

parent 2e51006e
......@@ -21,7 +21,7 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template<class M>
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
......@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
}
}
template<class Op, bool Start, bool Matches>
template <class Op, bool Start, bool Matches>
struct folder
{
template <class... Ts>
......@@ -257,7 +257,7 @@ struct folder
});
}
template<class Selector>
template <class Selector>
auto operator[](Selector select) const
{
return [=](auto... ms) {
......@@ -266,8 +266,8 @@ struct folder
bool matches = Start;
select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) == ctx.not_found());
})(Start, ms...));
return op(x, y.match(ctx, ins) == ctx.not_found());
})(Start, ms...));
});
if(matches == Matches)
return start;
......@@ -277,14 +277,14 @@ struct folder
}
};
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline auto inputs()
{
return [](auto ins, auto f) {
for(auto&& x:ins->inputs())
for(auto&& x : ins->inputs())
f(x);
};
}
......@@ -292,7 +292,7 @@ inline auto inputs()
inline auto outputs()
{
return [](auto ins, auto f) {
for(auto&& x:ins->outputs())
for(auto&& x : ins->outputs())
f(x);
};
}
......@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
{
if (ins->inputs().empty())
if(ins->inputs().empty())
return false;
auto s = ins->inputs().front()->get_shape();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) {
return x->get_shape() == s;
});
return std::all_of(
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
......@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
template<class... Ms>
template <class... Ms>
auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
......@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if (ctx.matched(m, next))
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if (skipped_next != ctx.not_found())
if(skipped_next != ctx.not_found())
return skipped_next;
}
return next;
......@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher(
[ =, names = std::move(names) ](instruction_ref ins) { return names.count(ins->name()) > 0; });
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
inline auto nargs(std::size_t n)
......
......@@ -25,10 +25,7 @@ const auto& reshaper_names()
return names;
}
bool is_reshaper(instruction_ref ins)
{
return contains(reshaper_names(), ins->name());
}
bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
instruction_ref find_transpose_input(instruction_ref ins)
{
......@@ -90,7 +87,8 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names())));
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, match::matcher_result mr) const
......@@ -139,14 +137,15 @@ struct find_transpose
{
auto matcher() const
{
return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose"))));
return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
auto x = ins;
auto t = ins;
auto x = ins;
auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0);
do
......@@ -172,13 +171,14 @@ struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape()));
return match::name("concat")(match::same_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
auto s = ins->inputs().front()->get_shape();
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
......@@ -187,10 +187,9 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t);
......@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p, ins,
find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
}
}
......
......@@ -362,10 +362,10 @@ TEST_CASE(match_none_of2)
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
......@@ -375,10 +375,10 @@ TEST_CASE(match_output1)
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
......@@ -388,10 +388,10 @@ TEST_CASE(match_output2)
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
......@@ -401,11 +401,11 @@ TEST_CASE(match_skip_output1)
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
......@@ -415,13 +415,13 @@ TEST_CASE(match_skip_output2)
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
......@@ -431,10 +431,10 @@ TEST_CASE(match_skip_output3)
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
......@@ -444,8 +444,8 @@ TEST_CASE(match_skip_output4)
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
......@@ -459,12 +459,12 @@ TEST_CASE(match_skip_output5)
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
......@@ -474,11 +474,11 @@ TEST_CASE(match_skip_output6)
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment