Commit 17129c88 authored by Paul's avatar Paul
Browse files

Merge branch 'master' into fusion

parents 050a354e 2a2b4a97
...@@ -4,6 +4,6 @@ danmar/cppcheck@f965e5873 -DHAVE_RULES=1 ...@@ -4,6 +4,6 @@ danmar/cppcheck@f965e5873 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c
# Needed for clang-ocl # Needed for clang-ocl
RadeonOpenCompute/rocm-cmake@d82a77c --build RadeonOpenCompute/rocm-cmake@d82a77c --build
RadeonOpenCompute/clang-ocl@a180592885ecae5b8beadf667c633c246cec82b6 RadeonOpenCompute/clang-ocl@799713643b5591a3b877c586ef2c7fbc012af819
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 # python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace migraph { namespace migraph {
namespace matchers { namespace match {
struct matcher_context struct matcher_context
{ {
...@@ -167,21 +167,21 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -167,21 +167,21 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
} }
/// This macro takes care of the boilerplate for defining a matcher /// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...) \ #define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
{ \ { \
instruction_ref match(__VA_ARGS__) const; \ instruction_ref match(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraph::matchers::basic_matcher<name##_m>{{}}; \ const constexpr auto name = migraph::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher /// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \ #define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
{ \ { \
bool operator()(__VA_ARGS__) const; \ bool operator()(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraph::matchers::basic_matcher<predicate_matcher<name##_m>>{{}}; \ const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result struct matcher_result
...@@ -294,7 +294,7 @@ struct args_impl_ints ...@@ -294,7 +294,7 @@ struct args_impl_ints
template <std::size_t... Ns, class... Ms> template <std::size_t... Ns, class... Ms>
auto args_impl(args_impl_ints<Ns...>, Ms... ms) auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{ {
return matchers::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...); return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
} }
template <class... Ms> template <class... Ms>
...@@ -306,7 +306,7 @@ auto args(Ms... ms) ...@@ -306,7 +306,7 @@ auto args(Ms... ms)
}); });
} }
} // namespace matchers } // namespace match
} // namespace migraph } // namespace migraph
......
...@@ -184,16 +184,17 @@ void memory_coloring_impl::build() ...@@ -184,16 +184,17 @@ void memory_coloring_impl::build()
void memory_coloring_impl::register_operand_alias() void memory_coloring_impl::register_operand_alias()
{ {
operand_alias["hip::allocate"] = -1; operand_alias["hip::allocate"] = -1;
operand_alias["@outline"] = -1; operand_alias["hip::load_literal"] = -1;
operand_alias["check_context"] = -1; operand_alias["@outline"] = -1;
operand_alias["@literal"] = -1; operand_alias["check_context"] = -1;
operand_alias["@param"] = -1; operand_alias["@literal"] = -1;
operand_alias["transpose"] = 0; operand_alias["@param"] = -1;
operand_alias["flatten"] = 0; operand_alias["transpose"] = 0;
operand_alias["broadcast"] = 1; operand_alias["flatten"] = 0;
operand_alias["reshape"] = 0; operand_alias["broadcast"] = 1;
operand_alias["pass"] = 0; operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
} }
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
......
...@@ -29,12 +29,12 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -29,12 +29,12 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
memory_coloring{"hip::allocate"},
fuse_ops{}, fuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_contiguous{}, eliminate_contiguous{},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
memory_coloring{"hip::allocate"},
eliminate_workspace{}, eliminate_workspace{},
eliminate_allocation{"hip::allocate"}, eliminate_allocation{"hip::allocate"},
check_context<context>{}, check_context<context>{},
......
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
namespace matchers = migraph::matchers; namespace match = migraph::match;
template <class M> template <class M>
migraph::matchers::matcher_result find_match(migraph::program& p, M&& m) migraph::match::matcher_result find_match(migraph::program& p, M&& m)
{ {
migraph::matchers::matcher_result result; migraph::match::matcher_result result;
for(auto ins : migraph::iterator_for(p)) for(auto ins : migraph::iterator_for(p))
{ {
result = migraph::matchers::match_instruction(p, ins, m); result = migraph::match::match_instruction(p, ins, m);
if(result.result != p.end()) if(result.result != p.end())
return result; return result;
} }
...@@ -22,7 +22,7 @@ void match1() ...@@ -22,7 +22,7 @@ void match1()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto m = matchers::standard_shape(); auto m = match::standard_shape();
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == l}); EXPECT(bool{r.result == l});
} }
...@@ -34,7 +34,7 @@ void match_name1() ...@@ -34,7 +34,7 @@ void match_name1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum"); auto m = match::name("sum");
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -46,7 +46,7 @@ void match_name2() ...@@ -46,7 +46,7 @@ void match_name2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("min"); auto m = match::name("min");
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -58,7 +58,7 @@ void match_name3() ...@@ -58,7 +58,7 @@ void match_name3()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::standard_shape()); auto m = match::name("sum")(match::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -70,8 +70,7 @@ void match_arg1() ...@@ -70,8 +70,7 @@ void match_arg1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")), auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -83,8 +82,7 @@ void match_arg2() ...@@ -83,8 +82,7 @@ void match_arg2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -96,8 +94,7 @@ void match_arg3() ...@@ -96,8 +94,7 @@ void match_arg3()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")), auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -109,9 +106,8 @@ void match_arg4() ...@@ -109,9 +106,8 @@ void match_arg4()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum); auto pass = p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape()); auto r = find_match(p, m);
auto r = find_match(p, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -122,8 +118,7 @@ void match_arg5() ...@@ -122,8 +118,7 @@ void match_arg5()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -135,7 +130,7 @@ void match_arg6() ...@@ -135,7 +130,7 @@ void match_arg6()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal"))); auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -147,8 +142,8 @@ void match_arg7() ...@@ -147,8 +142,8 @@ void match_arg7()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")), auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))); match::arg(1)(match::name("@literal")));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -160,9 +155,8 @@ void match_args1() ...@@ -160,9 +155,8 @@ void match_args1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")( auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
matchers::args(matchers::name("@literal"), matchers::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -174,9 +168,8 @@ void match_args2() ...@@ -174,9 +168,8 @@ void match_args2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
matchers::name("sum")(matchers::args(matchers::name("@literal"), matchers::name("sum")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -188,8 +181,7 @@ void match_args3() ...@@ -188,8 +181,7 @@ void match_args3()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::args(matchers::name("@literal")), auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -202,9 +194,8 @@ void match_args4() ...@@ -202,9 +194,8 @@ void match_args4()
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two); auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2); p.add_instruction(pass_op{}, sum2);
auto m = auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -216,9 +207,8 @@ void match_args5() ...@@ -216,9 +207,8 @@ void match_args5()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")), match::standard_shape());
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -230,9 +220,8 @@ void match_args6() ...@@ -230,9 +220,8 @@ void match_args6()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum); auto pass = p.add_instruction(pass_op{}, sum);
auto m = auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
matchers::name("pass")(matchers::args(matchers::name("sum")), matchers::standard_shape()); auto r = find_match(p, m);
auto r = find_match(p, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -243,9 +232,9 @@ void match_args7() ...@@ -243,9 +232,9 @@ void match_args7()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum); auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::args(matchers::name("sum")(matchers::args( auto m = match::name("pass")(match::args(match::name("sum")(match::args(
matchers::name("@literal"), matchers::name("@literal")))), match::name("@literal"), match::name("@literal")))),
matchers::standard_shape()); match::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -257,8 +246,8 @@ void match_all_of1() ...@@ -257,8 +246,8 @@ void match_all_of1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("@literal")), auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
matchers::arg(1)(matchers::name("@literal")))); match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -270,8 +259,8 @@ void match_all_of2() ...@@ -270,8 +259,8 @@ void match_all_of2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("sum")), auto m = match::name("sum")(
matchers::arg(1)(matchers::name("@literal")))); match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -283,8 +272,8 @@ void match_any_of1() ...@@ -283,8 +272,8 @@ void match_any_of1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")), auto m = match::name("sum")(
matchers::arg(1)(matchers::name("@literal")))); match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -296,8 +285,8 @@ void match_any_of2() ...@@ -296,8 +285,8 @@ void match_any_of2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")), auto m = match::name("sum")(
matchers::arg(1)(matchers::name("sum")))); match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -309,8 +298,8 @@ void match_none_of1() ...@@ -309,8 +298,8 @@ void match_none_of1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("sum")), auto m = match::name("sum")(
matchers::arg(1)(matchers::name("sum")))); match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -322,8 +311,8 @@ void match_none_of2() ...@@ -322,8 +311,8 @@ void match_none_of2()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("@literal")), auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
matchers::arg(1)(matchers::name("@literal")))); match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -335,12 +324,11 @@ void match_bind1() ...@@ -335,12 +324,11 @@ void match_bind1()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum); auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")( auto m = match::name("pass")(
matchers::args( match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
matchers::name("sum")(matchers::args(matchers::name("@literal").bind("one"), match::name("@literal").bind("two")))
matchers::name("@literal").bind("two"))) .bind("sum")),
.bind("sum")), match::standard_shape())
matchers::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.instructions.at("one") == one}); EXPECT(bool{r.instructions.at("one") == one});
...@@ -353,20 +341,17 @@ void match_bind1() ...@@ -353,20 +341,17 @@ void match_bind1()
struct match_find_sum struct match_find_sum
{ {
migraph::instruction_ref ins; migraph::instruction_ref ins;
auto matcher() const { return matchers::name("sum"); } auto matcher() const { return match::name("sum"); }
void apply(migraph::program&, matchers::matcher_result r) const void apply(migraph::program&, match::matcher_result r) const { EXPECT(bool{r.result == ins}); }
{
EXPECT(bool{r.result == ins});
}
}; };
struct match_find_literal struct match_find_literal
{ {
migraph::instruction_ref ins; migraph::instruction_ref ins;
auto matcher() const { return matchers::name("@literal"); } auto matcher() const { return match::name("@literal"); }
void apply(migraph::program&, matchers::matcher_result r) const void apply(migraph::program&, match::matcher_result r) const
{ {
EXPECT(bool{r.result != ins}); EXPECT(bool{r.result != ins});
EXPECT(r.result->name() == "@literal"); EXPECT(r.result->name() == "@literal");
...@@ -380,7 +365,7 @@ void match_finder() ...@@ -380,7 +365,7 @@ void match_finder()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
matchers::find_matches(p, match_find_sum{sum}, match_find_literal{sum}); match::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
} }
int main() int main()
......
#include <migraph/memory_coloring.hpp> #include <migraph/memory_coloring.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -112,10 +113,22 @@ void test4() ...@@ -112,10 +113,22 @@ void test4()
EXPECT(p.get_parameter_shape("scratch").bytes() == 672); EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
} }
void literal_test()
{
migraph::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
p.compile(memory_coloring_target{});
auto result = p.eval({});
EXPECT(lit == result);
}
int main() int main()
{ {
test1(); test1();
test2(); test2();
test3(); test3();
test4(); test4();
literal_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment