Commit 18e96032 authored by Shiv's avatar Shiv
Browse files

add layernom fuse update

parents 2c93aa87 0d94f068
...@@ -43,16 +43,23 @@ struct layernorm_matcher ...@@ -43,16 +43,23 @@ struct layernorm_matcher
auto variance() const auto variance() const
{ {
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))))); return f("reduce_mean")(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(f("reduce_mean")))))));
} }
auto layernorm_onnx() const auto sqrt_add_eps(const std::string& name) const
{ {
auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))); auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return f("div")( return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
arg(0)(x_minus_mean()), }
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance())))))); auto layernorm_onnx() const
{
auto div_sqrt = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt));
} }
auto matcher() const { return layernorm_onnx(); } auto matcher() const { return layernorm_onnx(); }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
...@@ -340,12 +341,18 @@ struct find_inner_broadcast ...@@ -340,12 +341,18 @@ struct find_inner_broadcast
std::back_inserter(inputs), std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); }); [](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape(); return i->get_shape() != inputs.front()->get_shape() and
i->get_shape().elements() != 1;
})) }))
return; return;
auto op = m.insert_instruction(ins, ins->get_operator(), inputs); auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); return not i->get_shape().scalar();
});
if(b_it == broadcasts.end())
b_it = broadcasts.begin();
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
} }
}; };
......
...@@ -30,14 +30,14 @@ ...@@ -30,14 +30,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
namespace gpu { namespace gpu {
struct prefuse_ops struct prefuse_ops
{ {
std::string name() const { return "gpu::prefuse_ops"; } std::string name() const { return "gpu::prefuse_ops"; }
void apply(module& m) const; void apply(module_pass_manager& m) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -90,7 +92,9 @@ struct find_layernorm ...@@ -90,7 +92,9 @@ struct find_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto eps = r.instructions["eps"]->eval().at<float>(); float eps = 0;
if(contains(r.instructions, "eps"))
eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, layernorm{eps}, x_ins); m.replace_instruction(ins, layernorm{eps}, x_ins);
} }
...@@ -100,26 +104,26 @@ struct find_add_layernorm ...@@ -100,26 +104,26 @@ struct find_add_layernorm
{ {
auto matcher() const auto matcher() const
{ {
return match::layernorm()( return match::name("gpu::prelayernorm")(
match::var("x")(match::name("add")(match::used_once()).bind("add"))); match::args(match::name("add")(match::used_once()).bind("add")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
float eps = 0; auto op = any_cast<layernorm>(ins->get_operator());
if(contains(r.instructions, "eps"))
eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs()); m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
} }
}; };
} // namespace } // namespace
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module_pass_manager& mpm) const
{ {
match::find_matches(m, find_add_layernorm{}, find_layernorm{}); match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment