Commit 17129c88 authored by Paul's avatar Paul
Browse files

Merge branch 'master' into fusion

parents 050a354e 2a2b4a97
......@@ -4,6 +4,6 @@ danmar/cppcheck@f965e5873 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c
# Needed for clang-ocl
RadeonOpenCompute/rocm-cmake@d82a77c --build
RadeonOpenCompute/clang-ocl@a180592885ecae5b8beadf667c633c246cec82b6
RadeonOpenCompute/clang-ocl@799713643b5591a3b877c586ef2c7fbc012af819
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt
......@@ -10,7 +10,7 @@
namespace migraph {
namespace matchers {
namespace match {
struct matcher_context
{
......@@ -167,21 +167,21 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<name##_m>{{}}; \
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<predicate_matcher<name##_m>>{{}}; \
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result
......@@ -294,7 +294,7 @@ struct args_impl_ints
template <std::size_t... Ns, class... Ms>
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
return matchers::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
}
template <class... Ms>
......@@ -306,7 +306,7 @@ auto args(Ms... ms)
});
}
} // namespace matchers
} // namespace match
} // namespace migraph
......
......@@ -184,16 +184,17 @@ void memory_coloring_impl::build()
void memory_coloring_impl::register_operand_alias()
{
operand_alias["hip::allocate"] = -1;
operand_alias["@outline"] = -1;
operand_alias["check_context"] = -1;
operand_alias["@literal"] = -1;
operand_alias["@param"] = -1;
operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 1;
operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
operand_alias["hip::allocate"] = -1;
operand_alias["hip::load_literal"] = -1;
operand_alias["@outline"] = -1;
operand_alias["check_context"] = -1;
operand_alias["@literal"] = -1;
operand_alias["@param"] = -1;
operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 1;
operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
}
void memory_coloring_impl::rewrite()
......
......@@ -29,12 +29,12 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
memory_coloring{"hip::allocate"},
fuse_ops{},
dead_code_elimination{},
eliminate_contiguous{},
dead_code_elimination{},
write_literals{&ctx},
memory_coloring{"hip::allocate"},
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
check_context<context>{},
......
......@@ -3,15 +3,15 @@
#include <test.hpp>
#include <basic_ops.hpp>
namespace matchers = migraph::matchers;
namespace match = migraph::match;
template <class M>
migraph::matchers::matcher_result find_match(migraph::program& p, M&& m)
migraph::match::matcher_result find_match(migraph::program& p, M&& m)
{
migraph::matchers::matcher_result result;
migraph::match::matcher_result result;
for(auto ins : migraph::iterator_for(p))
{
result = migraph::matchers::match_instruction(p, ins, m);
result = migraph::match::match_instruction(p, ins, m);
if(result.result != p.end())
return result;
}
......@@ -22,7 +22,7 @@ void match1()
{
migraph::program p;
auto l = p.add_literal(1);
auto m = matchers::standard_shape();
auto m = match::standard_shape();
auto r = find_match(p, m);
EXPECT(bool{r.result == l});
}
......@@ -34,7 +34,7 @@ void match_name1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum");
auto m = match::name("sum");
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -46,7 +46,7 @@ void match_name2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("min");
auto m = match::name("min");
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -58,7 +58,7 @@ void match_name3()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::standard_shape());
auto m = match::name("sum")(match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -70,8 +70,7 @@ void match_arg1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -83,8 +82,7 @@ void match_arg2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -96,8 +94,7 @@ void match_arg3()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -109,9 +106,8 @@ void match_arg4()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
......@@ -122,8 +118,7 @@ void match_arg5()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -135,7 +130,7 @@ void match_arg6()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")));
auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -147,8 +142,8 @@ void match_arg7()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal")));
auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -160,9 +155,8 @@ void match_args1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(
matchers::args(matchers::name("@literal"), matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -174,9 +168,8 @@ void match_args2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("@literal"), matchers::name("sum")),
matchers::standard_shape());
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -188,8 +181,7 @@ void match_args3()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::args(matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -202,9 +194,8 @@ void match_args4()
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
......@@ -216,9 +207,8 @@ void match_args5()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -230,9 +220,8 @@ void match_args6()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::args(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
......@@ -243,9 +232,9 @@ void match_args7()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::args(matchers::name("sum")(matchers::args(
matchers::name("@literal"), matchers::name("@literal")))),
matchers::standard_shape());
auto m = match::name("pass")(match::args(match::name("sum")(match::args(
match::name("@literal"), match::name("@literal")))),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
......@@ -257,8 +246,8 @@ void match_all_of1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -270,8 +259,8 @@ void match_all_of2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -283,8 +272,8 @@ void match_any_of1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -296,8 +285,8 @@ void match_any_of2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -309,8 +298,8 @@ void match_none_of1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto m = match::name("sum")(
match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
......@@ -322,8 +311,8 @@ void match_none_of2()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
......@@ -335,12 +324,11 @@ void match_bind1()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(
matchers::args(
matchers::name("sum")(matchers::args(matchers::name("@literal").bind("one"),
matchers::name("@literal").bind("two")))
.bind("sum")),
matchers::standard_shape())
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(p, m);
EXPECT(bool{r.instructions.at("one") == one});
......@@ -353,20 +341,17 @@ void match_bind1()
struct match_find_sum
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("sum"); }
auto matcher() const { return match::name("sum"); }
void apply(migraph::program&, matchers::matcher_result r) const
{
EXPECT(bool{r.result == ins});
}
void apply(migraph::program&, match::matcher_result r) const { EXPECT(bool{r.result == ins}); }
};
struct match_find_literal
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("@literal"); }
auto matcher() const { return match::name("@literal"); }
void apply(migraph::program&, matchers::matcher_result r) const
void apply(migraph::program&, match::matcher_result r) const
{
EXPECT(bool{r.result != ins});
EXPECT(r.result->name() == "@literal");
......@@ -380,7 +365,7 @@ void match_finder()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
matchers::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
match::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
}
int main()
......
#include <migraph/memory_coloring.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -112,10 +113,22 @@ void test4()
EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
}
void literal_test()
{
migraph::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
p.compile(memory_coloring_target{});
auto result = p.eval({});
EXPECT(lit == result);
}
int main()
{
test1();
test2();
test3();
test4();
literal_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment