Unverified Commit c72a047f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Dont match or bind to global instructions (#826)



* Add optional header

* Formatting

* Use optional in the matcher

* Foramtting

* Remove program from tests

* Formatting

* Dont bind or match non-local variables

* Formatting

* Fix gcc 5 error

* Format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9a5e0c06
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/program.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -19,24 +19,51 @@ namespace match { ...@@ -19,24 +19,51 @@ namespace match {
struct matcher_context struct matcher_context
{ {
matcher_context(instruction_ref i) : last(i) {} matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M> template <class M>
bool matched(M m, instruction_ref ins) bool matched(M m, instruction_ref ins)
{ {
return m.match(*this, ins) != this->not_found(); return has_value(m.match(*this, ins));
} }
template <class M> template <class M>
auto lazy_match(M m, instruction_ref ins) bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{ {
return [=] { return this->matched(m, ins); }; return [=] { return this->matched(m, ins); };
} }
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private: private:
instruction_ref last; module* mod = nullptr;
}; };
/// Convert a predicate function into a matcher /// Convert a predicate function into a matcher
...@@ -45,12 +72,11 @@ struct predicate_matcher ...@@ -45,12 +72,11 @@ struct predicate_matcher
{ {
P p; P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{ {
assert(ins != ctx.not_found());
if(p(ins)) if(p(ins))
return ins; return optional<instruction_ref>{ins};
return ctx.not_found(); return nullopt;
} }
}; };
...@@ -60,11 +86,7 @@ struct function_matcher ...@@ -60,11 +86,7 @@ struct function_matcher
{ {
F f; F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
}; };
/// Convert a function into a matcher /// Convert a function into a matcher
...@@ -79,12 +101,17 @@ template <class M> ...@@ -79,12 +101,17 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) { [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
auto result = m.match(ctx, ins); ->optional<instruction_ref> {
if(result != ctx.not_found()) auto result = m.match(ctx, ins);
ctx.instructions[name] = ins; if(result)
return result; {
}); if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
} }
/// Convert a matcher to a bindable matcher /// Convert a matcher to a bindable matcher
...@@ -95,10 +122,7 @@ struct bindable_matcher ...@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a bindable matcher /// Create a bindable matcher
...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>; ...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher struct id_matcher
{ {
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; } auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
}; };
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
...@@ -140,26 +167,23 @@ struct basic_matcher ...@@ -140,26 +167,23 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result != ctx.not_found()) if(result)
{ {
bool matches = fold([&](auto x, auto y) { bool matches =
return x and y.match(ctx, result) != ctx.not_found(); fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
})(true, ms...);
if(matches) if(matches)
return result; return result;
} }
return ctx.not_found(); return nullopt;
}); });
} }
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a basic matcher from a matcher /// Create a basic matcher from a matcher
...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a typed-erased matcher /// Create a typed-erased matcher
using any_matcher_base = basic_matcher< using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>; function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base struct any_matcher : any_matcher_base
{ {
template <class M> template <class M>
...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base ...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
{ \ { \
instruction_ref match(__VA_ARGS__) const; \ optional<instruction_ref> match(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \ const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher /// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \ #define MIGRAPHX_PRED_MATCHER(name, ...) \
...@@ -221,21 +245,29 @@ struct matcher_result ...@@ -221,21 +245,29 @@ struct matcher_result
/// Match a single instruction /// Match a single instruction
template <class M> template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m) matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
assert(ins != p.end()); assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result; matcher_result result;
matcher_context ctx{p.end()}; if(m.match(ctx, ins))
result.result = m.match(ctx, ins); {
result.instructions = ctx.instructions; result.result = ins;
result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result; return result;
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program /// Find matches for an instruction in the module
template <class... Ms> template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms) void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms) ...@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
auto r = match_instruction(p, ins, m.matcher()); auto r = match_instruction(mod, ins, m.matcher());
if(r.result == p.end()) if(r.result == mod.end())
return; return;
if(trace) if(trace)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins); mod.debug_print(ins);
} }
m.apply(p, r); m.apply(mod, r);
match = true; match = true;
}, },
ms...); ms...);
} }
/// Find matches in a program /// Find matches in a module
template <class... Ms> template <class... Ms>
void find_matches(module& p, Ms&&... ms) void find_matches(module& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(mod))
{ {
find_matches(p, ins, ms...); find_matches(mod, ins, ms...);
} }
} }
...@@ -339,12 +371,13 @@ struct match_fold_f ...@@ -339,12 +371,13 @@ struct match_fold_f
template <class... Ts> template <class... Ts>
auto operator()(Ts... ms) const auto operator()(Ts... ms) const
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher(
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...); [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
if(matches == Matches) bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
return ins; if(matches == Matches)
return ctx.not_found(); return {ins};
}); return nullopt;
});
} }
template <class Selector> template <class Selector>
...@@ -353,17 +386,18 @@ struct match_fold_f ...@@ -353,17 +386,18 @@ struct match_fold_f
return [=](auto... ms) { return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object // Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...); auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) { return make_bf_matcher(
Op op; [=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
bool matches = Start; Op op;
select(start, [&](auto ins) { bool matches = Start;
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); }; select(start, [&](auto ins) {
matches = op(always(matches), fm)(); auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
});
if(matches == Matches)
return {start};
return nullopt;
}); });
if(matches == Matches)
return start;
return ctx.not_found();
});
}; };
} }
}; };
...@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins->outputs().front(); return {ins->outputs().front()};
return ctx.not_found(); return nullopt;
} }
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins; return {ins};
if(ins->outputs().empty() and std::next(ins) == ctx.not_found()) if(ins->outputs().empty() and ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
} }
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().empty() and ins != std::prev(ctx.not_found())) if(ins->outputs().empty() and not ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
} }
template <class... Ms> template <class... Ms>
...@@ -485,14 +484,15 @@ auto skip(Ms... ms) ...@@ -485,14 +484,15 @@ auto skip(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
if(ins->inputs().size() == 1 and ctx.matched(m, ins)) [&](auto self, auto ins) -> optional<instruction_ref> {
{ if(ins->inputs().size() == 1 and ctx.matched(m, ins))
auto next = ins->inputs().front(); {
return self(next); auto next = ins->inputs().front();
} return self(next);
return ins; }
})(start); return ins;
})(start);
}); });
} }
...@@ -501,20 +501,21 @@ auto skip_output(Ms... ms) ...@@ -501,20 +501,21 @@ auto skip_output(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
if(ins->outputs().size() == 1) [&](auto self, auto ins) -> optional<instruction_ref> {
{ if(ins->outputs().size() == 1)
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{ {
auto skipped_next = self(next); auto next = ins->outputs().front();
if(skipped_next != ctx.not_found()) if(ctx.matched(m, next))
return skipped_next; {
auto skipped_next = self(next);
if(skipped_next)
return skipped_next;
}
return next;
} }
return next; return nullopt;
} })(start);
return ctx.not_found();
})(start);
}); });
} }
...@@ -550,11 +551,12 @@ inline auto nargs(std::size_t n) ...@@ -550,11 +551,12 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i) inline auto arg(std::size_t i)
{ {
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
if(i < ins->inputs().size()) [=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
return ins->inputs()[i]; if(i < ins->inputs().size())
return ctx.not_found(); return ins->inputs()[i];
}); return nullopt;
});
} }
// Workaround for bugs in clang // Workaround for bugs in clang
...@@ -616,52 +618,55 @@ std::size_t tree_leafs_impl(matcher_context& ctx, ...@@ -616,52 +618,55 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms> template <class M, class... Ms>
auto tree(M main_op, Ms... ms) auto tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
// Flatten leaf nodes [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
std::array<instruction_ref, sizeof...(Ms)> leafs; // Flatten leaf nodes
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::array<instruction_ref, sizeof...(Ms)> leafs;
if(idx != leafs.size()) std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
return ctx.not_found(); if(idx != leafs.size())
// Use explicit captures to workaround ICE on gcc return nullopt;
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) { // Use explicit captures to workaround ICE on gcc
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)(); bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return nullopt;
return ins;
}); });
if(not found)
return ctx.not_found();
return ins;
});
} }
template <class M, class... Ms> template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms) auto unordered_tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
// Flatten leaf nodes [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
std::array<instruction_ref, sizeof...(Ms)> leafs; // Flatten leaf nodes
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::array<instruction_ref, sizeof...(Ms)> leafs;
if(idx != leafs.size()) std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
return ctx.not_found(); if(idx != leafs.size())
// Use explicit captures to workaround ICE on gcc return nullopt;
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) { // Use explicit captures to workaround ICE on gcc
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) { bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...); return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
})(ms...)(); return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return nullopt;
return ins;
}); });
if(not found)
return ctx.not_found();
return ins;
});
} }
template <class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
auto i = m.match(ctx, ins); [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
if(i != ctx.not_found() and i->get_shape() == ins->get_shape()) auto i = m.match(ctx, ins);
return ins; if(i and (*i)->get_shape() == ins->get_shape())
return ctx.not_found(); return ins;
}); return nullopt;
});
} }
template <class... Ms> template <class... Ms>
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
...@@ -22,576 +22,498 @@ migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m) ...@@ -22,576 +22,498 @@ migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
void match1() void match1()
{ {
migraphx::program p; migraphx::module mm;
auto l = mm.add_literal(1);
auto* mm = p.get_main_module(); auto m = match::standard_shape();
auto l = mm->add_literal(1); auto r = find_match(mm, m);
auto m = match::standard_shape();
auto r = find_match(*mm, m);
EXPECT(bool{r.result == l}); EXPECT(bool{r.result == l});
} }
TEST_CASE(match_name1) TEST_CASE(match_name1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum"); auto m = match::name("sum");
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_name2) TEST_CASE(match_name2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("min"); auto m = match::name("min");
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_name3) TEST_CASE(match_name3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::standard_shape()); auto m = match::name("sum")(match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg1) TEST_CASE(match_arg1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape()); auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg2) TEST_CASE(match_arg2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape()); auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_arg3) TEST_CASE(match_arg3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape()); auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg4) TEST_CASE(match_arg4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto pass = mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape()); auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
TEST_CASE(match_arg5) TEST_CASE(match_arg5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape()); auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_arg6) TEST_CASE(match_arg6)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal"))); auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg7) TEST_CASE(match_arg7)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")), auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))); match::arg(1)(match::name("@literal")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg8) TEST_CASE(match_arg8)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")), auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))), match::arg(1)(match::name("@literal"))),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_nargs1) TEST_CASE(match_nargs1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2)); auto m = match::name("sum")(match::nargs(2));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_nargs2) TEST_CASE(match_nargs2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape()); auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_nargs3) TEST_CASE(match_nargs3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2))); auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_args1) TEST_CASE(match_args1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")), auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_args2) TEST_CASE(match_args2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")), auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_args3) TEST_CASE(match_args3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape()); auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_args4) TEST_CASE(match_args4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")), auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_args5) TEST_CASE(match_args5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")), auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_args6) TEST_CASE(match_args6)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto pass = mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape()); auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
TEST_CASE(match_args7) TEST_CASE(match_args7)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto pass = mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")(match::args( auto m = match::name("pass")(match::args(match::name("sum")(match::args(
match::name("@literal"), match::name("@literal")))), match::name("@literal"), match::name("@literal")))),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
TEST_CASE(match_either_args1) TEST_CASE(match_either_args1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal"))); match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_either_args2) TEST_CASE(match_either_args2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum"))); match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_either_args3) TEST_CASE(match_either_args3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal"))); match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_either_args_any1) TEST_CASE(match_either_args_any1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"))); match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
} }
TEST_CASE(match_either_args_any2) TEST_CASE(match_either_args_any2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")( auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
} }
TEST_CASE(match_either_args_any3) TEST_CASE(match_either_args_any3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")( auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
} }
TEST_CASE(match_either_args_any4) TEST_CASE(match_either_args_any4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")( auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
} }
TEST_CASE(match_either_args_any5) TEST_CASE(match_either_args_any5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")( auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
} }
TEST_CASE(match_all_of1) TEST_CASE(match_all_of1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")), auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")))); match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_all_of2) TEST_CASE(match_all_of2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal")))); match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_all_of3) TEST_CASE(match_all_of3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of( auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal"))))); match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_lazy_any_of) TEST_CASE(match_lazy_any_of)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); mm.add_instruction(pass_op{}, one);
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws()); auto m = match::any_of(match::any(), throws());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == one}); EXPECT(bool{r.result == one});
} }
TEST_CASE(match_lazy_all_of) TEST_CASE(match_lazy_all_of)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); mm.add_instruction(pass_op{}, one);
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws()); auto m = match::all_of(match::none(), throws());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_lazy_none_of) TEST_CASE(match_lazy_none_of)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); mm.add_instruction(pass_op{}, one);
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws()); auto m = match::none_of(match::any(), throws());
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal")))); match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_any_of2) TEST_CASE(match_any_of2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")))); match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_any_of_lazy1) TEST_CASE(match_any_of_lazy1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"), match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y"))); match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum}); EXPECT(bool{r.instructions["x"] == sum});
...@@ -600,17 +522,15 @@ TEST_CASE(match_any_of_lazy1) ...@@ -600,17 +522,15 @@ TEST_CASE(match_any_of_lazy1)
TEST_CASE(match_any_of_lazy2) TEST_CASE(match_any_of_lazy2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"), match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y"))); match::args(match::any(), match::any()).bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum}); EXPECT(bool{r.instructions["x"] == sum});
...@@ -619,17 +539,15 @@ TEST_CASE(match_any_of_lazy2) ...@@ -619,17 +539,15 @@ TEST_CASE(match_any_of_lazy2)
TEST_CASE(match_any_of_lazy3) TEST_CASE(match_any_of_lazy3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"), match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y"))); match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum}); EXPECT(bool{r.instructions["x"] == sum});
...@@ -638,17 +556,15 @@ TEST_CASE(match_any_of_lazy3) ...@@ -638,17 +556,15 @@ TEST_CASE(match_any_of_lazy3)
TEST_CASE(match_any_of_lazy4) TEST_CASE(match_any_of_lazy4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of( auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")), match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2")))); match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1")); EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1")); EXPECT(migraphx::contains(r.instructions, "y1"));
...@@ -660,17 +576,15 @@ TEST_CASE(match_any_of_lazy4) ...@@ -660,17 +576,15 @@ TEST_CASE(match_any_of_lazy4)
TEST_CASE(match_any_of_lazy5) TEST_CASE(match_any_of_lazy5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of( auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")), match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2")))); match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1")); EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1")); EXPECT(migraphx::contains(r.instructions, "y1"));
...@@ -682,194 +596,170 @@ TEST_CASE(match_any_of_lazy5) ...@@ -682,194 +596,170 @@ TEST_CASE(match_any_of_lazy5)
TEST_CASE(match_none_of1) TEST_CASE(match_none_of1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")( auto m = match::name("sum")(
match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")))); match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_none_of2) TEST_CASE(match_none_of2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")), auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")))); match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_output1) TEST_CASE(match_output1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto sum = mm.add_instruction(sum_op{}, minus, two);
auto minus = mm->add_instruction(minus_op{}, two, one); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum"))); auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus}); EXPECT(bool{r.result == minus});
} }
TEST_CASE(match_output2) TEST_CASE(match_output2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto sum = mm.add_instruction(sum_op{}, minus, two);
auto minus = mm->add_instruction(minus_op{}, two, one); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum"))); auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_skip_output1) TEST_CASE(match_skip_output1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto sum = mm.add_instruction(sum_op{}, minus, two);
auto minus = mm->add_instruction(minus_op{}, two, one); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus}); EXPECT(bool{r.result == minus});
} }
TEST_CASE(match_skip_output2) TEST_CASE(match_skip_output2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto minus_pass = mm.add_instruction(pass_op{}, minus);
auto minus = mm->add_instruction(minus_op{}, two, one); auto sum = mm.add_instruction(sum_op{}, minus_pass, two);
auto minus_pass = mm->add_instruction(pass_op{}, minus); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, minus_pass, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus}); EXPECT(bool{r.result == minus});
} }
TEST_CASE(match_skip_output3) TEST_CASE(match_skip_output3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto minus_pass1 = mm.add_instruction(pass_op{}, minus);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus_pass2 = mm.add_instruction(pass_op{}, minus_pass1);
auto minus_pass1 = mm->add_instruction(pass_op{}, minus); auto minus_pass3 = mm.add_instruction(pass_op{}, minus_pass2);
auto minus_pass2 = mm->add_instruction(pass_op{}, minus_pass1); auto sum = mm.add_instruction(sum_op{}, minus_pass3, two);
auto minus_pass3 = mm->add_instruction(pass_op{}, minus_pass2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, minus_pass3, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus}); EXPECT(bool{r.result == minus});
} }
TEST_CASE(match_skip_output4) TEST_CASE(match_skip_output4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto pass = mm.add_instruction(pass_op{}, one);
auto two = mm->add_literal(2); auto sum = mm.add_instruction(sum_op{}, pass, two);
auto pass = mm->add_instruction(pass_op{}, one); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, pass, two);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == two}); EXPECT(bool{r.result == two});
} }
TEST_CASE(match_skip_output5) TEST_CASE(match_skip_output5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto pass = mm.add_instruction(pass_op{}, one);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, pass, two);
auto pass = mm->add_instruction(pass_op{}, one); auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum1 = mm->add_instruction(sum_op{}, pass, two); auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one); mm.add_instruction(pass_op{}, sum3);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_skip_output6) TEST_CASE(match_skip_output6)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, minus, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum1 = mm->add_instruction(sum_op{}, minus, two); auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one); mm.add_instruction(pass_op{}, sum3);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus}); EXPECT(bool{r.result == minus});
} }
TEST_CASE(match_skip_output7) TEST_CASE(match_skip_output7)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto minus1 = mm.add_instruction(minus_op{}, two, one);
auto two = mm->add_literal(2); auto minus2 = mm.add_instruction(minus_op{}, two, minus1);
auto minus1 = mm->add_instruction(minus_op{}, two, one); auto sum = mm.add_instruction(sum_op{}, one, minus2);
auto minus2 = mm->add_instruction(minus_op{}, two, minus1); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, minus2);
mm->add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus"))); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == minus1}); EXPECT(bool{r.result == minus1});
} }
TEST_CASE(match_bind1) TEST_CASE(match_bind1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto pass = mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")( auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"), match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two"))) match::name("@literal").bind("two")))
.bind("sum")), .bind("sum")),
match::standard_shape()) match::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one}); EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two}); EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum}); EXPECT(bool{r.instructions.at("sum") == sum});
...@@ -877,265 +767,280 @@ TEST_CASE(match_bind1) ...@@ -877,265 +767,280 @@ TEST_CASE(match_bind1)
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
TEST_CASE(match_has_value1) TEST_CASE(match_bind_modules1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(not migraphx::contains(r.instructions, "one"));
EXPECT(not migraphx::contains(r.instructions, "two"));
EXPECT(not migraphx::contains(r.instructions, "sum"));
EXPECT(not migraphx::contains(r.instructions, "pass"));
EXPECT(bool{r.result == child->end()});
}
TEST_CASE(match_bind_modules2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
auto pass = child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.result == pass});
}
auto* mm = p.get_main_module(); TEST_CASE(match_has_value1)
auto one = mm->add_literal(1); {
auto two = mm->add_literal(2); migraphx::module mm;
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto one = mm.add_literal(1);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two); auto two = mm.add_literal(2);
mm->add_instruction(pass_op{}, sum2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(1); auto m = match::has_value(1);
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == one}); EXPECT(bool{r.result == one});
} }
TEST_CASE(match_has_value2) TEST_CASE(match_has_value2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::has_value(2); auto m = match::has_value(2);
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == two}); EXPECT(bool{r.result == two});
} }
TEST_CASE(match_has_value3) TEST_CASE(match_has_value3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2))); auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
} }
TEST_CASE(match_has_value4) TEST_CASE(match_has_value4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::has_value(3); auto m = match::has_value(3);
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_has_value5) TEST_CASE(match_has_value5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3))); auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_has_value6) TEST_CASE(match_has_value6)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1))); auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_tree1) TEST_CASE(match_tree1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree( auto m = match::tree(
match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3)); match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_tree2) TEST_CASE(match_tree2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree( auto m = match::tree(
match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3)); match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_tree3) TEST_CASE(match_tree3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree( auto m = match::tree(
match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2)); match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_tree4) TEST_CASE(match_tree4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), auto m = match::tree(match::name("sum"),
match::has_value(1), match::has_value(1),
match::has_value(2), match::has_value(2),
match::has_value(3), match::has_value(3),
match::has_value(4)); match::has_value(4));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_tree5) TEST_CASE(match_tree5)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3)); auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_tree6) TEST_CASE(match_tree6)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3)); auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
TEST_CASE(match_unordered_tree1) TEST_CASE(match_unordered_tree1)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree( auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_unordered_tree2) TEST_CASE(match_unordered_tree2)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree( auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_unordered_tree3) TEST_CASE(match_unordered_tree3)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, two, one);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, two, one); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree( auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
TEST_CASE(match_unordered_tree4) TEST_CASE(match_unordered_tree4)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto three = mm.add_literal(3);
auto two = mm->add_literal(2); auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto three = mm->add_literal(3); auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
auto sum1 = mm->add_instruction(sum_op{}, one, two); mm.add_instruction(pass_op{}, sum2);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree( auto m = match::unordered_tree(
match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm.end()});
} }
struct match_find_sum struct match_find_sum
...@@ -1163,14 +1068,12 @@ struct match_find_literal ...@@ -1163,14 +1068,12 @@ struct match_find_literal
TEST_CASE(match_finder) TEST_CASE(match_finder)
{ {
migraphx::program p; migraphx::module mm;
auto one = mm.add_literal(1);
auto* mm = p.get_main_module(); auto two = mm.add_literal(2);
auto one = mm->add_literal(1); auto sum = mm.add_instruction(sum_op{}, one, two);
auto two = mm->add_literal(2); mm.add_instruction(pass_op{}, sum);
auto sum = mm->add_instruction(sum_op{}, one, two); match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum});
mm->add_instruction(pass_op{}, sum);
match::find_matches(*mm, match_find_sum{sum}, match_find_literal{sum});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment