Commit 03afd1ed authored by Paul's avatar Paul
Browse files

Fix matcher bugs in either_arg

parent 43a96492
......@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions.emplace(name, ins);
ctx.instructions[name] = ins;
return result;
});
}
......
......@@ -52,6 +52,7 @@ struct find_mul_conv
}
};
// a * (x + b) => a * x + a * b
struct find_mul_add
{
auto matcher() const
......@@ -60,7 +61,7 @@ struct find_mul_add
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("y")),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
......@@ -70,12 +71,13 @@ struct find_mul_add
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
assert(x_ins != b_ins);
auto xa_ins = p.insert_instruction(ins, op::mul{}, x_ins, a_ins);
auto ya_ins = p.insert_instruction(ins, op::mul{}, y_ins, a_ins);
p.replace_instruction(ins, op::add{}, xa_ins, ya_ins);
auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
}
};
......
......@@ -5,6 +5,8 @@
namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{
......@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1)
{
migraphx::program p;
......@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
......@@ -396,6 +503,92 @@ TEST_CASE(match_any_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"), match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"), match::args(match::any(), match::any()).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"), match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")), match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any().bind("x1"), match::any().bind("y1")), match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_none_of1)
{
migraphx::program p;
......
......@@ -150,4 +150,30 @@ TEST_CASE(simplify_mul_conv1)
EXPECT(new_conv->outputs().front()->name() != "mul");
}
TEST_CASE(simplify_mul_add)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul);
}
p1.compile(simplify_algebra_target{});
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
auto sum = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment