Commit 48492136 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add skip_broadcasts to simplify algebra value matchers

Adds this to handle broadcasted values instead of just scalars
parent 4aeacc17
...@@ -866,14 +866,14 @@ struct find_unit_ops ...@@ -866,14 +866,14 @@ struct find_unit_ops
{ {
auto matcher() const auto matcher() const
{ {
auto mul_1 = match::name("mul")( auto mul_1 = match::name("mul")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x"))); match::skip_broadcasts(match::has_value(1.0f)), match::any().bind("x")));
auto div_1 = auto div_1 = match::name("div")(
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f))); match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(1.0f))));
auto add_0 = match::name("add")( auto add_0 = match::name("add")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::has_value(0.0f), match::any().bind("x"))); match::skip_broadcasts(match::has_value(0.0f)), match::any().bind("x")));
auto sub_0 = auto sub_0 = match::name("sub")(
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f))); match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(0.0f))));
return match::any_of(mul_1, div_1, add_0, sub_0); return match::any_of(mul_1, div_1, add_0, sub_0);
} }
...@@ -890,12 +890,12 @@ struct find_neg_unit_ops ...@@ -890,12 +890,12 @@ struct find_neg_unit_ops
{ {
auto matcher() const auto matcher() const
{ {
auto mul_neg_1 = match::name("mul")( auto mul_neg_1 = match::name("mul")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x"))); match::skip_broadcasts(match::has_value(-1.0f)), match::any().bind("x")));
auto div_neg_1 = auto div_neg_1 = match::name("div")(
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f))); match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(-1.0f))));
auto sub_0 = auto sub_0 = match::name("sub")(
match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x"))); match::args(match::skip_broadcasts(match::has_value(0.0f)), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0); return match::any_of(mul_neg_1, div_neg_1, sub_0);
} }
...@@ -913,10 +913,10 @@ struct find_zero_ops ...@@ -913,10 +913,10 @@ struct find_zero_ops
{ {
auto matcher() const auto matcher() const
{ {
auto mul_zero = match::name("mul")( auto mul_zero = match::name("mul")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any())); match::skip_broadcasts(match::has_value(0.0f).bind("x")), match::any()));
auto div_zero = auto div_zero = match::name("div")(
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any())); match::args(match::skip_broadcasts(match::has_value(0.0f).bind("x")), match::any()));
return match::any_of(mul_zero, div_zero); return match::any_of(mul_zero, div_zero);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment