Commit 53c744bf authored by Khalique's avatar Khalique
Browse files

Merge branch 'tf-transpose' of...

Merge branch 'tf-transpose' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf-transpose_ss
parents c808514d 9ca52537
...@@ -259,13 +259,13 @@ struct lazy_or ...@@ -259,13 +259,13 @@ struct lazy_or
}; };
template <class Op, bool Start, bool Matches> template <class Op, bool Start, bool Matches>
struct folder struct match_fold_f
{ {
template <class... Ms> template <class... Ms>
static bool fold_match(matcher_context& ctx, instruction_ref ins, Ms... ms) static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
{ {
Op op; Op op;
auto matched = [&](auto m) { return [&] { return ctx.matched(m, ins); }; }; auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...); return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
} }
...@@ -273,7 +273,7 @@ struct folder ...@@ -273,7 +273,7 @@ struct folder
auto operator()(Ts... ms) const auto operator()(Ts... ms) const
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = folder::fold_match(ctx, ins, ms...); bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches) if(matches == Matches)
return ins; return ins;
return ctx.not_found(); return ctx.not_found();
...@@ -283,12 +283,18 @@ struct folder ...@@ -283,12 +283,18 @@ struct folder
template <class Selector> template <class Selector>
auto operator[](Selector select) const auto operator[](Selector select) const
{ {
return [=](auto... ms) { return [=](auto... mms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(mms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op; Op op;
bool matches = Start; bool matches = Start;
select(start, [&](auto ins) { select(start, [&](auto ins) {
auto fm = [&] { return folder::fold_match(ctx, ins, ms...); }; auto fm = [&] {
return mpack([&](auto... ms) {
return match_fold_f::fold_matchers(ctx, ins, ms...);
});
};
matches = op(always(matches), fm); matches = op(always(matches), fm);
}); });
if(matches == Matches) if(matches == Matches)
...@@ -299,9 +305,9 @@ struct folder ...@@ -299,9 +305,9 @@ struct folder
} }
}; };
const constexpr auto all_of = folder<lazy_and, true, true>{}; const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of = folder<lazy_or, false, true>{}; const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = folder<lazy_or, false, false>{}; const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
inline auto inputs() inline auto inputs()
{ {
......
...@@ -162,6 +162,42 @@ TEST_CASE(match_arg8) ...@@ -162,6 +162,42 @@ TEST_CASE(match_arg8)
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1) TEST_CASE(match_args1)
{ {
migraphx::program p; migraphx::program p;
...@@ -321,6 +357,19 @@ TEST_CASE(match_all_of2) ...@@ -321,6 +357,19 @@ TEST_CASE(match_all_of2)
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment