Unverified Commit 01907e27 authored by Chris Austen, committed by GitHub

Update matcher to better match layernorm fusions in other models (#1548) (#1567)

* Fuse layernorm with different patterns
* Only match when using the last axis
parent b1d3c954
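
For context (an equivalence implied by the new patterns, not stated in the commit message): every spelling the matcher now accepts computes the same layernorm. With mu = mean(x) taken over the last axis,

    var = mean((x - mu)^2) = mean((x - mu) * (x - mu)) = mean(sqdiff(x, mu))
    y   = (x - mu) / sqrt(var + eps) = (x - mu) * rsqrt(var + eps)

so the variance may come from pow, mul, or sqdiff, the normalization may use div/sqrt or mul/rsqrt, and the eps add is optional (eps then defaults to 0 in the fused op).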
@@ -36,23 +36,46 @@ template <class F>
 struct layernorm_matcher
 {
     F f;
+
+    auto last_axis() const
+    {
+        return make_basic_pred_matcher([](instruction_ref ins) {
+            auto v = ins->get_operator().to_value();
+            if(not v.contains("axes"))
+                return false;
+            auto axes = v["axes"].to_vector<std::size_t>();
+            if(axes.size() != 1)
+                return false;
+            return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
+        });
+    }
+
+    auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
+
     auto x_minus_mean() const
     {
-        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
+        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
     }
 
     auto variance() const
     {
-        return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
+        return reduce_mean()(arg(0)(any_of(
+            f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
+            f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
+            f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
     }
 
-    auto layernorm_onnx() const
+    auto sqrt_add_eps(const std::string& name) const
     {
         auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
-        return f("div")(
-            arg(0)(x_minus_mean()),
-            arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance()))))));
+        return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
+    }
+
+    auto layernorm_onnx() const
+    {
+        auto div_sqrt  = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
+        auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
+        return any(any_of(div_sqrt, mul_rsqrt));
     }
 
     auto matcher() const { return layernorm_onnx(); }
...
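
As an illustration of what the widened matcher should now catch (a sketch, not part of this diff; the graph is hypothetical, built with the same make_op/add_instruction API the tests below use, and the shapes are invented), a layernorm written with a mul-based variance, no explicit epsilon, and rsqrt would previously be skipped but now fuses:

    // Hypothetical graph: variance via mul, normalization via rsqrt, no eps add.
    migraphx::module m;
    migraphx::shape s{migraphx::shape::float_type, {1, 384}};
    auto x    = m.add_parameter("x", s);
    auto mu   = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
    auto mub  = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), mu);
    auto diff = m.add_instruction(migraphx::make_op("sub"), x, mub);     // x - mean(x)
    auto sq   = m.add_instruction(migraphx::make_op("mul"), diff, diff); // squared without pow
    auto var  = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), sq);
    auto rs   = m.add_instruction(migraphx::make_op("rsqrt"), var);      // eps add omitted: still matches
    auto rsb  = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), rs);
    m.add_instruction(migraphx::make_op("mul"), diff, rsb);

Note that last_axis() keys the match off reduce_mean's axes attribute ({1} here, the final axis of {1, 384}); a reduction over any other axis, or over several axes, is rejected.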
@@ -31,6 +31,7 @@
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/matcher.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
@@ -340,12 +341,18 @@ struct find_inner_broadcast
                        std::back_inserter(inputs),
                        [](auto i) { return i->inputs().front(); });
         if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape();
+               return i->get_shape() != inputs.front()->get_shape() and
+                      i->get_shape().elements() != 1;
            }))
             return;
-        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
+        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
+            return not i->get_shape().scalar();
+        });
+        if(b_it == broadcasts.end())
+            b_it = broadcasts.begin();
+        auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, (*b_it)->get_operator(), op);
     }
 };
...
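
Concretely (my reading of the rewrite, with shapes borrowed from the new test case at the bottom of this diff): a scalar-like input no longer blocks hoisting the op above the broadcasts, and insert_common_op, the reason for the new common.hpp include above, supplies whatever intermediate broadcast the now-mismatched input shapes need:

    // before: xb = multibroadcast[{32, 384}](x : {1, 384})
    //         yb = multibroadcast[{32, 384}](y : {1, 1})
    //         add(xb, yb)
    //
    // after:  yb' = multibroadcast[{1, 384}](y)     // inserted by insert_common_op
    //         s   = add(x, yb')
    //         multibroadcast[{32, 384}](s)          // operator of the first non-scalar broadcast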
@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
-struct module;
+struct module_pass_manager;
 
 namespace gpu {
 
 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };
 
 } // namespace gpu
...
@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps   = r.instructions["eps"]->eval().at<float>();
+        float eps = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();
 
         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
@@ -100,26 +104,26 @@ struct find_add_layernorm
     auto matcher() const
     {
-        return match::layernorm()(
-            match::var("x")(match::name("add")(match::used_once()).bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }
 
     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins     = r.result;
         auto add_ins = r.instructions["add"];
-        float eps    = 0;
-        if(contains(r.instructions, "eps"))
-            eps = r.instructions["eps"]->eval().at<float>();
+        auto op      = any_cast<layernorm>(ins->get_operator());
 
-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };
 
 } // namespace
 
-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }
 
 } // namespace gpu
...
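
The reworked apply is worth annotating (my reading, based on the used_once constraint in find_add_layernorm's matcher): the dead-code pass between the two match phases matters, because the mean/variance arithmetic left behind by the first rewrite would otherwise still count as uses of the add and defeat used_once. The same code as above, with comments:

    void prefuse_ops::apply(module_pass_manager& mpm) const
    {
        // Phase 1: collapse each matched subgraph into a single gpu::prelayernorm op.
        match::find_matches(mpm.get_module(), find_layernorm{});
        // Remove the now-unreferenced arithmetic so the add feeding the
        // layernorm is genuinely used once.
        mpm.run_pass(dead_code_elimination{});
        // Phase 2: fold that add into add_layernorm.
        match::find_matches(mpm.get_module(), find_add_layernorm{});
    }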
@@ -559,6 +559,32 @@ TEST_CASE(simplify_inner_broadcast2)
     EXPECT(m1 == m2);
 }
 
+TEST_CASE(simplify_inner_broadcast_scalar)
+{
+    auto b = migraphx::op::multibroadcast{{32, 384}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto yb   = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), x, yb);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;
...