"vscode:/vscode.git/clone" did not exist on "3bcc796d382025f6a7a9bc24cadf2d44b554790d"
Unverified Commit 728d083d authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add cpu fusion for gelu and layernorm (#761)



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use set since shape::type_t is not hashable on gcc 5

* Formatting

* Add dnnl binary op

* Formatting

* Add binary and eltwise

* Formatting

* Add softmax

* Formatting

* Remove unused operators

* Add missing files

* Formatting

* Add lrn

* Formatting

* Add deconvolution

* Formatting

* Change allocate default

* Add reorder

* Formatting

* Add reductions

* Formatting

* Sort lines

* Change literals in another loop

* Add pow operator

* Formatting

* Add pow operator

* Formatting

* Make sure shapes are packed

* Allow broadcasted inputs

* Remove unused operators

* Simplify functions

* Remove softmax

* Add sub and erf functions

* Formatting

* Fix bug

* Formatting

* Improve parallelism

* Formatting

* Allow multiple batch dimensions

* Formatting

* Move literal transforms out of lowering

* Formatting

* Add gather operator

* Sort lines

* Add early exit for carry

* Formatting

* Add missing concat

* Rename macro

* Fix deep nesting

* Formatting

* Fix cppcheck issues

* Remove else

* Move attribute to typedef

* Formatting

* Disable maybe-uninitialized warning since it's broken on gcc

* Add constexpr default constructor

* Formatting

* Fix compiler warnings

* Fix adjust_allocation test

* Add layernorm matcher

* Add gelu_erf matcher

* Formatting

* Add gelu_tanh matcher

* Formatting

* Remove match namespace

* Formatting

* Use matcher instead of string

* Formatting

* Add fusions

* Formatting

* Make input a const ref

* Make this explicit for gcc 5
Co-authored-by: default avatarShucai Xiao <shucai@gmail.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 51fb672d
#ifndef MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#define MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
// Matches the exact (erf-based) GELU pattern:
//   0.5 * x * (1 + erf(x / sqrt(2)))
// F is a factory mapping an operator name to a name matcher, which lets
// callers target prefixed ops (e.g. "gpu::mul") with the same pattern.
template <class F>
struct gelu_erf_matcher
{
F f;
// Matches erf(x * (1/sqrt(2))) and binds the non-constant operand as "x".
// Constants are compared with a 1e-3 tolerance to absorb float truncation.
auto erf_fn() const
{
return f("erf")(
used_once(),
arg(0)(used_once(),
f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)))));
}
// Matches 1 + erf(...), with the operands in either order.
auto add_erf() const
{
return f("add")(used_once(), either_arg(0, 1)(erf_fn(), has_value(1.0f)));
}
auto one_half() const { return has_value(0.5f); }
// Top level: a tree of "mul" nodes whose leaves are 0.5, (1 + erf(...)),
// and x in any order — unordered_tree flattens associative multiplies.
auto matcher() const { return unordered_tree(f("mul"), one_half(), add_erf(), any()); }
};
} // namespace detail
template <class F>
auto gelu_erf(F f)
{
return detail::gelu_erf_matcher<F>{f}.matcher();
}
// Convenience overload: match plain (unprefixed) operator names.
inline auto gelu_erf()
{
    auto by_name = [](auto x) { return name(x); };
    return gelu_erf(by_name);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#ifndef MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#define MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
// Matches the tanh approximation of GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// F is a factory mapping an operator name to a name matcher, which lets
// callers target prefixed ops (e.g. "gpu::mul") with the same pattern.
template <class F>
struct gelu_tanh_matcher
{
F f;
// Matches x^3 (pow with a constant exponent of 3).
auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); }
// Matches tanh(sqrt(2/pi) * (x + 0.044715 * x^3)); sqrt(2/pi) is compared
// with a 1e-3 tolerance to absorb float truncation in the graph literal.
auto tanh_fn() const
{
return f("tanh")(
used_once(),
arg(0)(f("mul")(either_arg(0, 1)(has_value(sqrt(M_2_PI), 1e-3),
f("add")(any_arg(0, 1)(f("mul")(either_arg(0, 1)(
has_value(0.044715f), pow_fn()))))))));
}
// Top level: x * (... + 0.5 * tanh(...)); the non-constant operand of the
// outer mul is bound as "x".
auto matcher() const
{
return f("mul")(used_once(),
either_arg(0, 1)(any().bind("x"),
f("add")(any_arg(0, 1)(f("mul")(
either_arg(0, 1)(has_value(0.5f), tanh_fn()))))));
}
};
} // namespace detail
template <class F>
auto gelu_tanh(F f)
{
return detail::gelu_tanh_matcher<F>{f}.matcher();
}
// Convenience overload: match plain (unprefixed) operator names.
inline auto gelu_tanh()
{
    auto by_name = [](auto x) { return name(x); };
    return gelu_tanh(by_name);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
// Matches the ONNX-style layernorm pattern:
//   (x - mean(x)) / sqrt(mean((x - mean(x))^2) + 1e-12)
// F is a factory mapping an operator name to a name matcher, which lets
// callers target prefixed ops (e.g. "gpu::sub") with the same pattern.
template <class F>
struct layernorm_matcher
{
F f;
// Matches x - mean(x); binds the minuend as "x". The mean may be wrapped
// in broadcast/contiguous ops, which skip_broadcasts walks past.
auto x_minus_mean() const
{
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
}
// Matches the variance: mean((x - mean(x))^2).
auto variance() const
{
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
}
// Full pattern: (x - mean) / sqrt(variance + eps), with eps fixed at the
// ONNX default 1e-12.
auto layernorm_onnx() const
{
return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")(
arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
}
auto matcher() const { return layernorm_onnx(); }
};
} // namespace detail
template <class F>
auto layernorm(F f)
{
return detail::layernorm_matcher<F>{f}.matcher();
}
// Convenience overload: match plain (unprefixed) operator names.
inline auto layernorm()
{
    auto by_name = [](auto x) { return name(x); };
    return layernorm(by_name);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -258,6 +258,22 @@ void find_matches(module& p, Ms&&... ms) ...@@ -258,6 +258,22 @@ void find_matches(module& p, Ms&&... ms)
} }
} }
// Adapter pairing a matcher with a callback so an arbitrary lambda can be
// used with find_matches, which expects objects exposing matcher()/apply().
template <class M, class F>
struct find_generic_match
{
M m;
F f;
M matcher() const { return m; }
// Forward the match result to the stored callback.
void apply(module& mod, const matcher_result& mr) const { f(mod, mr); }
};
// Factory deducing M and F for find_generic_match.
template <class M, class F>
find_generic_match<M, F> make_match_finder(M m, F f)
{
return {m, f};
}
template <class M> template <class M>
struct find_skip struct find_skip
{ {
...@@ -452,6 +468,22 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in ...@@ -452,6 +468,22 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return ctx.not_found(); return ctx.not_found();
} }
// Returns a matcher-transformer that walks past a chain of single-input
// instructions matching any of ms, yielding the first instruction that
// either fails to match or does not have exactly one input.
template <class... Ms>
auto skip(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
// fix provides the recursion handle since a lambda cannot name itself.
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
});
}
template <class... Ms> template <class... Ms>
auto skip_output(Ms... ms) auto skip_output(Ms... ms)
{ {
...@@ -547,15 +579,17 @@ inline auto any_arg(std::size_t i, std::size_t j) ...@@ -547,15 +579,17 @@ inline auto any_arg(std::size_t i, std::size_t j)
return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); }; return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); };
} }
template <std::size_t N> template <std::size_t N, class M>
std::size_t std::size_t tree_leafs_impl(matcher_context& ctx,
tree_leafs_impl(std::array<instruction_ref, N>& leafs, const std::string& s, instruction_ref ins) std::array<instruction_ref, N>& leafs,
M m,
instruction_ref ins)
{ {
std::size_t idx = 0; std::size_t idx = 0;
fix([&](auto self, auto i) { fix([&](auto self, auto i) {
if(idx == leafs.size()) if(idx == leafs.size())
return; return;
if(i->name() == s and i->inputs().size() >= 2) if(ctx.matched(m, i) and i->inputs().size() >= 2)
{ {
self(i->inputs()[0]); self(i->inputs()[0]);
self(i->inputs()[1]); self(i->inputs()[1]);
...@@ -567,13 +601,13 @@ tree_leafs_impl(std::array<instruction_ref, N>& leafs, const std::string& s, ins ...@@ -567,13 +601,13 @@ tree_leafs_impl(std::array<instruction_ref, N>& leafs, const std::string& s, ins
return idx; return idx;
} }
template <class... Ms> template <class M, class... Ms>
auto tree(std::string s, Ms... ms) auto tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes // Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs; std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(leafs, s, ins); std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size()) if(idx != leafs.size())
return ctx.not_found(); return ctx.not_found();
// Use explicit captures to workaround ICE on gcc // Use explicit captures to workaround ICE on gcc
...@@ -586,13 +620,13 @@ auto tree(std::string s, Ms... ms) ...@@ -586,13 +620,13 @@ auto tree(std::string s, Ms... ms)
}); });
} }
template <class... Ms> template <class M, class... Ms>
auto unordered_tree(std::string s, Ms... ms) auto unordered_tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes // Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs; std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(leafs, s, ins); std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size()) if(idx != leafs.size())
return ctx.not_found(); return ctx.not_found();
// Use explicit captures to workaround ICE on gcc // Use explicit captures to workaround ICE on gcc
...@@ -624,10 +658,16 @@ auto same_shape(Ms... ms) ...@@ -624,10 +658,16 @@ auto same_shape(Ms... ms)
return all_of(same_shape(ms)...); return all_of(same_shape(ms)...);
} }
// Applies ms to the instruction found after skipping over any chain of
// broadcast, multibroadcast, and contiguous instructions.
template <class... Ms>
auto skip_broadcasts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, float tolerance = 1e-6)
{ {
return make_basic_pred_matcher([=](instruction_ref ins) { return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal") if(ins->name() != "@literal")
return false; return false;
auto l = ins->get_literal(); auto l = ins->get_literal();
...@@ -640,7 +680,7 @@ inline auto has_value(T x, float tolerance = 1e-6) ...@@ -640,7 +680,7 @@ inline auto has_value(T x, float tolerance = 1e-6)
b = true; b = true;
}); });
return b; return b;
}); }));
} }
inline auto has_attribute(const std::string& name) inline auto has_attribute(const std::string& name)
......
...@@ -14,6 +14,7 @@ add_library(migraphx_cpu ...@@ -14,6 +14,7 @@ add_library(migraphx_cpu
erf.cpp erf.cpp
gather.cpp gather.cpp
gemm.cpp gemm.cpp
layernorm.cpp
logsoftmax.cpp logsoftmax.cpp
lowering.cpp lowering.cpp
lrn.cpp lrn.cpp
......
#include <migraphx/config.hpp>
#include <migraphx/cpu/dnnl.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
// oneDNN-backed layer normalization for the CPU target. Normalizes over the
// last dimension using layer_normalization_forward (inference mode).
struct dnnl_layernorm : dnnl_op<dnnl_layernorm, dnnl::layer_normalization_forward>
{
    // Constant added to the variance for numerical stability; 1e-12 matches
    // the eps in the layernorm pattern fused upstream.
    float epsilon = 1e-12f;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.epsilon, "epsilon"));
    }

    std::string name() const { return "dnnl::layernorm"; }

    // Output shape is the same as the (single) input shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // The last input is the preallocated output buffer added by the
        // allocation pass; drop it before validating.
        inputs.pop_back();
        check_shapes{inputs, *this}.has(1);
        auto s = inputs.at(0);
        // Call get_primitive to make sure an algo is available
        this->get_primitive(this->to_memory_desc(s, inputs));
        return s;
    }

    dnnl::layer_normalization_forward::desc
    get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        // Fix: use the reflected epsilon member rather than a hard-coded
        // 1e-12f, so a configured/deserialized epsilon actually takes effect.
        return {dnnl::prop_kind::forward_inference,
                m.at(DNNL_ARG_SRC),
                epsilon,
                dnnl::normalization_flags::none};
    }
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -29,6 +29,10 @@ ...@@ -29,6 +29,10 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/match/gelu_tanh.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <iostream> #include <iostream>
...@@ -336,6 +340,21 @@ struct cpu_apply ...@@ -336,6 +340,21 @@ struct cpu_apply
} }
} }
// Build a match finder that replaces a matched subgraph with a single op.
// bind_inputs names the matcher bindings (e.g. "x") whose instructions
// become the op's inputs; an output allocation is appended as the final
// argument since cpu ops write into a preallocated buffer.
template <class M>
auto fuse_match(M matcher, const operation& op, const std::vector<std::string>& bind_inputs)
{
return match::make_match_finder(matcher, [=](auto&, const auto& r) {
auto ins = r.result;
std::vector<instruction_ref> inputs;
// Look up each named binding recorded by the matcher.
std::transform(bind_inputs.begin(),
bind_inputs.end(),
std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); });
inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
this->modl->replace_instruction(ins, op, inputs);
});
}
void init() void init()
{ {
create_output_names(); create_output_names();
...@@ -388,6 +407,15 @@ struct cpu_apply ...@@ -388,6 +407,15 @@ struct cpu_apply
void apply() void apply()
{ {
init(); init();
// Apply fusion matchers first
match::find_matches(*modl,
fuse_match(match::gelu_erf(),
make_op("dnnl::eltwise", {{"algo", "eltwise_gelu_erf"}}),
{"x"}),
fuse_match(match::gelu_tanh(),
make_op("dnnl::eltwise", {{"algo", "eltwise_gelu_tanh"}}),
{"x"}),
fuse_match(match::layernorm(), make_op("dnnl::layernorm"), {"x"}));
// Apply these operators first so the inputs can be const folded // Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl)) for(auto it : iterator_for(*modl))
{ {
......
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
#include <migraphx/gpu/device/add_tanh.hpp> #include <migraphx/gpu/device/add_tanh.hpp>
#include <migraphx/gpu/device/mul_add_relu.hpp> #include <migraphx/gpu/device/mul_add_relu.hpp>
#include <migraphx/gpu/device/add.hpp> #include <migraphx/gpu/device/add.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/match/gelu_tanh.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
...@@ -295,39 +298,11 @@ void move_standard_front(std::vector<instruction_ref>& args) ...@@ -295,39 +298,11 @@ void move_standard_front(std::vector<instruction_ref>& args)
std::swap(*it, args.front()); std::swap(*it, args.front());
} }
auto gpu_name(const std::string& s) { return match::name("gpu::" + s); }
struct find_layernorm struct find_layernorm
{ {
template <class... Ts> auto matcher() const { return match::layernorm(&gpu_name); }
static auto multibroadcast_op(Ts... xs)
{
return match::name("multibroadcast")(match::arg(0)(xs...));
}
static auto x_minus_mean()
{
return match::name("gpu::sub")(
match::arg(0)(match::any().bind("x")),
match::arg(1)(multibroadcast_op(match::name("gpu::reduce_mean"))));
}
static auto variance()
{
return match::name("gpu::reduce_mean")(match::arg(0)(
match::name("gpu::pow")(match::arg(0)(x_minus_mean()),
match::arg(1)(multibroadcast_op(match::has_value(2.0f))))));
}
static auto layernorm_onnx()
{
return match::name("gpu::div")(
match::arg(0)(x_minus_mean()),
match::arg(1)(multibroadcast_op(
match::name("gpu::sqrt")(match::arg(0)(match::name("gpu::add")(match::either_arg(
0, 1)(variance(), multibroadcast_op(match::has_value(1e-12f)))))))));
}
auto matcher() const { return layernorm_onnx(); }
void apply(module& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
...@@ -366,30 +341,7 @@ struct find_triadd_layernorm ...@@ -366,30 +341,7 @@ struct find_triadd_layernorm
struct find_gelu struct find_gelu
{ {
auto matcher() const { return match::gelu_erf(&gpu_name); }
static auto erf_fn()
{
return match::name("gpu::erf")(
match::used_once(),
match::arg(0)(match::used_once(),
match::name("gpu::mul")(match::either_arg(0, 1)(
match::none_of(match::has_value(M_SQRT1_2, 1e-3)).bind("x"),
match::has_value(M_SQRT1_2, 1e-3)))));
}
static auto add_erf()
{
return match::name("gpu::add")(
match::used_once(),
match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f))));
}
static auto one_half() { return match::args(match::has_value(0.5f)); }
auto matcher() const
{
return match::unordered_tree("gpu::mul", one_half(), add_erf(), match::any());
}
void apply(module& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
...@@ -425,32 +377,7 @@ struct find_gelu_new ...@@ -425,32 +377,7 @@ struct find_gelu_new
{ {
bool fast_math = true; bool fast_math = true;
static auto pow_fn() auto matcher() const { return match::gelu_tanh(&gpu_name); }
{
return match::name("gpu::pow")(match::used_once(),
match::arg(1)(match::args(match::has_value(3.0f))));
}
static auto tanh_fn()
{
return match::name("gpu::tanh")(
match::used_once(),
match::arg(0)(match::name("gpu::mul")(match::either_arg(0, 1)(
match::args(match::has_value(sqrt(M_2_PI), 1e-3)),
match::name("gpu::add")(
match::any_arg(0, 1)(match::name("gpu::mul")(match::either_arg(0, 1)(
match::args(match::has_value(0.044715f)), pow_fn()))))))));
}
auto matcher() const
{
return match::name("gpu::mul")(
match::used_once(),
match::either_arg(0, 1)(
match::any().bind("x"),
match::name("gpu::add")(match::any_arg(0, 1)(match::name("gpu::mul")(
match::either_arg(0, 1)(match::args(match::has_value(0.5f)), tanh_fn()))))));
}
void apply(module& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
......
...@@ -978,7 +978,8 @@ TEST_CASE(match_tree1) ...@@ -978,7 +978,8 @@ TEST_CASE(match_tree1)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(2), match::has_value(3)); auto m = match::tree(
match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -994,7 +995,8 @@ TEST_CASE(match_tree2) ...@@ -994,7 +995,8 @@ TEST_CASE(match_tree2)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3)); auto m = match::tree(
match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm->end()});
} }
...@@ -1010,7 +1012,8 @@ TEST_CASE(match_tree3) ...@@ -1010,7 +1012,8 @@ TEST_CASE(match_tree3)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1); auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(3), match::has_value(1), match::has_value(2)); auto m = match::tree(
match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -1026,8 +1029,11 @@ TEST_CASE(match_tree4) ...@@ -1026,8 +1029,11 @@ TEST_CASE(match_tree4)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree( auto m = match::tree(match::name("sum"),
"sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4)); match::has_value(1),
match::has_value(2),
match::has_value(3),
match::has_value(4));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm->end()});
} }
...@@ -1043,7 +1049,7 @@ TEST_CASE(match_tree5) ...@@ -1043,7 +1049,7 @@ TEST_CASE(match_tree5)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(3)); auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm->end()});
} }
...@@ -1059,7 +1065,7 @@ TEST_CASE(match_tree6) ...@@ -1059,7 +1065,7 @@ TEST_CASE(match_tree6)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(3)); auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm->end()});
} }
...@@ -1075,8 +1081,8 @@ TEST_CASE(match_unordered_tree1) ...@@ -1075,8 +1081,8 @@ TEST_CASE(match_unordered_tree1)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = auto m = match::unordered_tree(
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -1092,8 +1098,8 @@ TEST_CASE(match_unordered_tree2) ...@@ -1092,8 +1098,8 @@ TEST_CASE(match_unordered_tree2)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1); auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = auto m = match::unordered_tree(
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -1109,8 +1115,8 @@ TEST_CASE(match_unordered_tree3) ...@@ -1109,8 +1115,8 @@ TEST_CASE(match_unordered_tree3)
auto sum1 = mm->add_instruction(sum_op{}, two, one); auto sum1 = mm->add_instruction(sum_op{}, two, one);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = auto m = match::unordered_tree(
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
...@@ -1126,8 +1132,8 @@ TEST_CASE(match_unordered_tree4) ...@@ -1126,8 +1132,8 @@ TEST_CASE(match_unordered_tree4)
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three); auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = auto m = match::unordered_tree(
match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1)); match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()}); EXPECT(bool{r.result == mm->end()});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment