"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a7ecf8bfa98c1b23f1a05be787e14704aa87a0a3"
Commit 7f4c7809 authored by Paul's avatar Paul
Browse files

Formatting

parent 2e51006e
...@@ -21,7 +21,7 @@ struct matcher_context ...@@ -21,7 +21,7 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; } instruction_ref not_found() const { return last; }
template<class M> template <class M>
bool matched(M m, instruction_ref ins) bool matched(M m, instruction_ref ins)
{ {
return m.match(*this, ins) != this->not_found(); return m.match(*this, ins) != this->not_found();
...@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms) ...@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
} }
} }
template<class Op, bool Start, bool Matches> template <class Op, bool Start, bool Matches>
struct folder struct folder
{ {
template <class... Ts> template <class... Ts>
...@@ -257,7 +257,7 @@ struct folder ...@@ -257,7 +257,7 @@ struct folder
}); });
} }
template<class Selector> template <class Selector>
auto operator[](Selector select) const auto operator[](Selector select) const
{ {
return [=](auto... ms) { return [=](auto... ms) {
...@@ -266,8 +266,8 @@ struct folder ...@@ -266,8 +266,8 @@ struct folder
bool matches = Start; bool matches = Start;
select(start, [&](auto ins) { select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) { matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) == ctx.not_found()); return op(x, y.match(ctx, ins) == ctx.not_found());
})(Start, ms...)); })(Start, ms...));
}); });
if(matches == Matches) if(matches == Matches)
return start; return start;
...@@ -277,14 +277,14 @@ struct folder ...@@ -277,14 +277,14 @@ struct folder
} }
}; };
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{}; const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{}; const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{}; const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline auto inputs() inline auto inputs()
{ {
return [](auto ins, auto f) { return [](auto ins, auto f) {
for(auto&& x:ins->inputs()) for(auto&& x : ins->inputs())
f(x); f(x);
}; };
} }
...@@ -292,7 +292,7 @@ inline auto inputs() ...@@ -292,7 +292,7 @@ inline auto inputs()
inline auto outputs() inline auto outputs()
{ {
return [](auto ins, auto f) { return [](auto ins, auto f) {
for(auto&& x:ins->outputs()) for(auto&& x : ins->outputs())
f(x); f(x);
}; };
} }
...@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins) ...@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins) MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
{ {
if (ins->inputs().empty()) if(ins->inputs().empty())
return false; return false;
auto s = ins->inputs().front()->get_shape(); auto s = ins->inputs().front()->get_shape();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return std::all_of(
return x->get_shape() == s; ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
});
} }
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
...@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) ...@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found(); return ctx.not_found();
} }
template<class... Ms> template <class... Ms>
auto skip_output(Ms... ms) auto skip_output(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
...@@ -345,10 +344,10 @@ auto skip_output(Ms... ms) ...@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
{ {
auto next = ins->outputs().front(); auto next = ins->outputs().front();
if (ctx.matched(m, next)) if(ctx.matched(m, next))
{ {
auto skipped_next = self(next); auto skipped_next = self(next);
if (skipped_next != ctx.not_found()) if(skipped_next != ctx.not_found())
return skipped_next; return skipped_next;
} }
return next; return next;
...@@ -366,8 +365,9 @@ inline auto name(std::string s) ...@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
[ =, names = std::move(names) ](instruction_ref ins) { return names.count(ins->name()) > 0; }); return names.count(ins->name()) > 0;
});
} }
inline auto nargs(std::size_t n) inline auto nargs(std::size_t n)
......
...@@ -25,10 +25,7 @@ const auto& reshaper_names() ...@@ -25,10 +25,7 @@ const auto& reshaper_names()
return names; return names;
} }
bool is_reshaper(instruction_ref ins) bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
{
return contains(reshaper_names(), ins->name());
}
instruction_ref find_transpose_input(instruction_ref ins) instruction_ref find_transpose_input(instruction_ref ins)
{ {
...@@ -90,7 +87,8 @@ struct find_reshaper ...@@ -90,7 +87,8 @@ struct find_reshaper
{ {
auto matcher() const auto matcher() const
{ {
return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names()))); return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
...@@ -139,14 +137,15 @@ struct find_transpose ...@@ -139,14 +137,15 @@ struct find_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")))); return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
auto t = ins; auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size()); std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
do do
...@@ -172,13 +171,14 @@ struct find_concat_transpose ...@@ -172,13 +171,14 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::same_shapes(),
match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(program& p, match::matcher_result mr) const void apply(program& p, match::matcher_result mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto s = ins->inputs().front()->get_shape(); auto s = ins->inputs().front()->get_shape();
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s); auto permutation = find_permutation(s);
...@@ -187,10 +187,9 @@ struct find_concat_transpose ...@@ -187,10 +187,9 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
ins->inputs().end(), return p.insert_instruction(ins, op::transpose{permutation}, i);
std::back_inserter(inputs), });
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat); auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
...@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const ...@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions // Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
match::find_matches(p, ins, match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
} }
} }
......
...@@ -362,10 +362,10 @@ TEST_CASE(match_none_of2) ...@@ -362,10 +362,10 @@ TEST_CASE(match_none_of2)
TEST_CASE(match_output1) TEST_CASE(match_output1)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two); auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum"))); auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -375,10 +375,10 @@ TEST_CASE(match_output1) ...@@ -375,10 +375,10 @@ TEST_CASE(match_output1)
TEST_CASE(match_output2) TEST_CASE(match_output2)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two); auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum"))); auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -388,10 +388,10 @@ TEST_CASE(match_output2) ...@@ -388,10 +388,10 @@ TEST_CASE(match_output2)
TEST_CASE(match_skip_output1) TEST_CASE(match_skip_output1)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two); auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -401,11 +401,11 @@ TEST_CASE(match_skip_output1) ...@@ -401,11 +401,11 @@ TEST_CASE(match_skip_output1)
TEST_CASE(match_skip_output2) TEST_CASE(match_skip_output2)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus); auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two); auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -415,13 +415,13 @@ TEST_CASE(match_skip_output2) ...@@ -415,13 +415,13 @@ TEST_CASE(match_skip_output2)
TEST_CASE(match_skip_output3) TEST_CASE(match_skip_output3)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus); auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1); auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2); auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two); auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -431,10 +431,10 @@ TEST_CASE(match_skip_output3) ...@@ -431,10 +431,10 @@ TEST_CASE(match_skip_output3)
TEST_CASE(match_skip_output4) TEST_CASE(match_skip_output4)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one); auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two); auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -444,8 +444,8 @@ TEST_CASE(match_skip_output4) ...@@ -444,8 +444,8 @@ TEST_CASE(match_skip_output4)
TEST_CASE(match_skip_output5) TEST_CASE(match_skip_output5)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one); auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two); auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one); auto sum2 = p.add_instruction(sum_op{}, sum1, one);
...@@ -459,12 +459,12 @@ TEST_CASE(match_skip_output5) ...@@ -459,12 +459,12 @@ TEST_CASE(match_skip_output5)
TEST_CASE(match_skip_output6) TEST_CASE(match_skip_output6)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two); auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one); auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two); auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3); p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -474,11 +474,11 @@ TEST_CASE(match_skip_output6) ...@@ -474,11 +474,11 @@ TEST_CASE(match_skip_output6)
TEST_CASE(match_skip_output7) TEST_CASE(match_skip_output7)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one); auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1); auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2); auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m); auto r = find_match(p, m);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment