Unverified commit 01907e27 authored by Chris Austen, committed by GitHub

Update matcher to better match layernorm fusions in other models (#1548) (#1567)

* Fuse layernorm with different patterns
* Only match when using the last axis
parent b1d3c954
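
As a rough illustration of the two bullets above (not part of the diff), the sketch below builds one of the decomposed layernorm forms the updated matcher is meant to fuse: the variance written as mul(d, d) rather than pow(d, 2), the normalization as a multiply by rsqrt rather than a divide by sqrt, and both reductions taken over the last axis only. The shapes, the epsilon value, and the helper name make_layernorm_mul_rsqrt are illustrative assumptions, not code from this commit.

// Hedged sketch, not part of this commit: one alternative layernorm
// decomposition the updated matcher is intended to recognize.
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>

migraphx::module make_layernorm_mul_rsqrt()
{
    migraphx::module m;
    auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 384, 768}});
    // reduce over the last axis only, which is what the new last_axis() predicate requires
    auto mean  = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
    auto meanb = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 384, 768}}}), mean);
    auto d = m.add_instruction(migraphx::make_op("sub"), x, meanb);
    // variance written as mul(d, d) instead of pow(d, 2)
    auto dd  = m.add_instruction(migraphx::make_op("mul"), d, d);
    auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), dd);
    auto eps =
        m.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-12f}});
    auto epsb = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 384, 1}}}), eps);
    auto veps = m.add_instruction(migraphx::make_op("add"), var, epsb);
    // normalization via mul(d, rsqrt(var + eps)) instead of div(d, sqrt(var + eps))
    auto r  = m.add_instruction(migraphx::make_op("rsqrt"), veps);
    auto rb = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 384, 768}}}), r);
    m.add_instruction(migraphx::make_op("mul"), d, rb);
    return m;
}
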
@@ -36,23 +36,46 @@ template <class F>
 struct layernorm_matcher
 {
     F f;

+    auto last_axis() const
+    {
+        return make_basic_pred_matcher([](instruction_ref ins) {
+            auto v = ins->get_operator().to_value();
+            if(not v.contains("axes"))
+                return false;
+            auto axes = v["axes"].to_vector<std::size_t>();
+            if(axes.size() != 1)
+                return false;
+            return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
+        });
+    }
+
+    auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
+
     auto x_minus_mean() const
     {
-        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
+        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
     }

     auto variance() const
     {
-        return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
+        return reduce_mean()(arg(0)(any_of(
+            f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
+            f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
+            f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
     }

-    auto layernorm_onnx() const
+    auto sqrt_add_eps(const std::string& name) const
     {
         auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
-        return f("div")(
-            arg(0)(x_minus_mean()),
-            arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance()))))));
+        return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
+    }
+
+    auto layernorm_onnx() const
+    {
+        auto div_sqrt  = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
+        auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
+        return any(any_of(div_sqrt, mul_rsqrt));
     }

     auto matcher() const { return layernorm_onnx(); }
@@ -31,6 +31,7 @@
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/matcher.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
@@ -340,12 +341,18 @@ struct find_inner_broadcast
                        std::back_inserter(inputs),
                        [](auto i) { return i->inputs().front(); });
         if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape();
+               return i->get_shape() != inputs.front()->get_shape() and
+                      i->get_shape().elements() != 1;
            }))
             return;

-        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
+        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
+            return not i->get_shape().scalar();
+        });
+        if(b_it == broadcasts.end())
+            b_it = broadcasts.begin();
+        auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, (*b_it)->get_operator(), op);
     }
 };

@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {

 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps = r.instructions["eps"]->eval().at<float>();
+        float eps = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();

         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
@@ -100,26 +104,26 @@ struct find_add_layernorm
 {
     auto matcher() const
     {
-        return match::layernorm()(
-            match::var("x")(match::name("add")(match::used_once()).bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto add_ins = r.instructions["add"];
-        float eps = 0;
-        if(contains(r.instructions, "eps"))
-            eps = r.instructions["eps"]->eval().at<float>();
+        auto op = any_cast<layernorm>(ins->get_operator());

-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };

 } // namespace

-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }

 } // namespace gpu
@@ -559,6 +559,32 @@ TEST_CASE(simplify_inner_broadcast2)
     EXPECT(m1 == m2);
 }

+TEST_CASE(simplify_inner_broadcast_scalar)
+{
+    auto b = migraphx::op::multibroadcast{{32, 384}};
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto xb = m1.add_instruction(b, x);
+        auto yb = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
+        auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;