Commit a4a654c9 authored by Paul's avatar Paul
Browse files

Add either_arg matcher

parent 7c416eb9
......@@ -310,6 +310,16 @@ auto args(Ms... ms)
});
}
auto either_arg(std::size_t i, std::size_t j)
{
return [=](auto m1, auto m2) {
return match::any_of(
match::all_of(arg(i)(m1), arg(j)(m2)),
match::all_of(arg(j)(m1), arg(i)(m2))
);
};
}
} // namespace match
} // namespace migraph
......
......@@ -167,11 +167,10 @@ struct match_conv_bias
context* ctx = nullptr;
auto matcher() const
{
return match::name("gpu::add")(match::any_of(
match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")),
match::arg(1)(match::name("gpu::convolution").bind("conv"))),
match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")),
match::arg(0)(match::name("gpu::convolution").bind("conv")))));
return match::name("gpu::add")(match::either_arg(0, 1)(
match::broadcast_shape().bind("bias"),
match::name("gpu::convolution").bind("conv")
));
}
void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment