"benchmark/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "e2574ee986d2bdd3d7a1dd1561d69920a6581cd8"
Commit 48492136 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add skip_broadcasts to simplify algebra value matchers

Adds this to handle broadcasted values instead of just scalars
parent 4aeacc17
......@@ -866,14 +866,14 @@ struct find_unit_ops
{
auto matcher() const
{
auto mul_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x")));
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
auto mul_1 = match::name("mul")(match::either_arg(0, 1)(
match::skip_broadcasts(match::has_value(1.0f)), match::any().bind("x")));
auto div_1 = match::name("div")(
match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(1.0f))));
auto add_0 = match::name("add")(match::either_arg(0, 1)(
match::skip_broadcasts(match::has_value(0.0f)), match::any().bind("x")));
auto sub_0 = match::name("sub")(
match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(0.0f))));
return match::any_of(mul_1, div_1, add_0, sub_0);
}
......@@ -890,12 +890,12 @@ struct find_neg_unit_ops
{
auto matcher() const
{
auto mul_neg_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x")));
auto div_neg_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
auto sub_0 =
match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x")));
auto mul_neg_1 = match::name("mul")(match::either_arg(0, 1)(
match::skip_broadcasts(match::has_value(-1.0f)), match::any().bind("x")));
auto div_neg_1 = match::name("div")(
match::args(match::any().bind("x"), match::skip_broadcasts(match::has_value(-1.0f))));
auto sub_0 = match::name("sub")(
match::args(match::skip_broadcasts(match::has_value(0.0f)), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0);
}
......@@ -913,10 +913,10 @@ struct find_zero_ops
{
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
auto mul_zero = match::name("mul")(match::either_arg(0, 1)(
match::skip_broadcasts(match::has_value(0.0f).bind("x")), match::any()));
auto div_zero = match::name("div")(
match::args(match::skip_broadcasts(match::has_value(0.0f).bind("x")), match::any()));
return match::any_of(mul_zero, div_zero);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment