Commit 99b79a7d authored by turneram's avatar turneram
Browse files

Move matcher to gelu_erf.hpp

parent 6d582c24
...@@ -67,6 +67,42 @@ inline auto gelu_erf() ...@@ -67,6 +67,42 @@ inline auto gelu_erf()
return gelu_erf([](auto x) { return name(x); }); return gelu_erf([](auto x) { return name(x); });
} }
namespace detail {
template <class F>
struct bert_gelu_erf_matcher
{
F f;
auto erf_fn() const
{
return f("erf")(
used_once(),
arg(0)(used_once(),
f("div")(either_arg(0, 1)(none_of(has_value(1.414f, 1e-3)).bind("x"),
has_value(1.414f, 1e-3)))));
}
auto add_erf() const
{
return f("add")(used_once(), either_arg(0, 1)(erf_fn(), has_value(1.0f)));
}
auto one_half() const { return has_value(0.5f); }
auto matcher() const { return unordered_tree(f("mul"), one_half(), add_erf(), any()); }
};
} // namespace detail
template <class F>
auto bert_gelu_erf(F f)
{
return detail::bert_gelu_erf_matcher<F>{f}.matcher();
}
inline auto bert_gelu_erf()
{
return bert_gelu_erf([](auto x) { return name(x); });
}
} // namespace match } // namespace match
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -25,35 +25,16 @@ ...@@ -25,35 +25,16 @@
#include <migraphx/rewrite_gelu.hpp> #include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct find_gelu_erf struct find_gelu_erf
{ {
static auto match_div()
{
return match::name("div")(match::either_arg(0, 1)(
match::any().bind("x"), match::skip_broadcasts(match::has_value(1.414f, 1e-3))));
}
static auto match_erf() { return match::name("erf")(match::arg(0)(match_div())); }
static auto match_add()
{
return match::name("add")(
match::either_arg(0, 1)(match_erf(), match::skip_broadcasts(match::has_value(1.0f))));
}
static auto match_mul()
{
return match::name("mul")(match::either_arg(0, 1)(match::any(), match_add()));
}
auto matcher() const auto matcher() const
{ {
return match::name("mul")( return match::bert_gelu_erf();
match::either_arg(0, 1)(match_mul(), match::skip_broadcasts(match::has_value(0.5f))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment