"examples/research_projects/controlnetxs/README_sdxl.md" did not exist on "78b87dc25aa3cb5eab282354d9b001b90a75cca4"
Unverified commit 3e58b1e4, authored by Paul Fultz II, committed by GitHub

Update matcher to better match layernorm fusions in other models (#1548)



* Fuse layernorm with different patterns
* Only match when using the last axis
Co-authored-by: kahmed10 <15948690+kahmed10@users.noreply.github.com>
parent edda9f93
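
For context on the "different patterns" above: exported models decompose layer normalization in several algebraically equivalent ways, and the matcher below is widened to cover more of them: the variance may be written with pow, mul, or sqdiff; the normalization with div/sqrt or mul/rsqrt; and the reduction must be over the last axis. A minimal sketch of one newly covered variant, built with the MIGraphX module API that also appears in the test hunk at the end of this diff; header paths and the exact attribute spelling are assumptions, not taken from this commit:

// Sketch only: layernorm written with mul for the variance and rsqrt for the
// normalization, a shape the old matcher (pow + div/sqrt only) would not catch.
#include <migraphx/module.hpp>            // header paths assumed
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/multibroadcast.hpp>

migraphx::module make_layernorm_variant()
{
    migraphx::module m;
    auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 384, 768}});
    // Mean over the last axis; the new last_axis() predicate only accepts this case.
    auto mean  = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
    auto meanb = m.add_instruction(migraphx::op::multibroadcast{{1, 384, 768}}, mean);
    auto d     = m.add_instruction(migraphx::make_op("sub"), x, meanb);
    // Variance as (x - mean) * (x - mean) rather than pow(x - mean, 2).
    auto dd  = m.add_instruction(migraphx::make_op("mul"), d, d);
    auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), dd);
    // Epsilon as a broadcast constant; this is what binds to "eps" in the matcher.
    auto eps  = m.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5f}});
    auto epsb = m.add_instruction(migraphx::op::multibroadcast{{1, 384, 1}}, eps);
    auto ve   = m.add_instruction(migraphx::make_op("add"), var, epsb);
    // Normalize with mul + rsqrt instead of div + sqrt.
    auto rs  = m.add_instruction(migraphx::make_op("rsqrt"), ve);
    auto rsb = m.add_instruction(migraphx::op::multibroadcast{{1, 384, 768}}, rs);
    m.add_instruction(migraphx::make_op("mul"), d, rsb);
    return m;
}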
@@ -36,23 +36,46 @@ template <class F>
 struct layernorm_matcher
 {
     F f;
+    auto last_axis() const
+    {
+        return make_basic_pred_matcher([](instruction_ref ins) {
+            auto v = ins->get_operator().to_value();
+            if(not v.contains("axes"))
+                return false;
+            auto axes = v["axes"].to_vector<std::size_t>();
+            if(axes.size() != 1)
+                return false;
+            return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
+        });
+    }
+    auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
     auto x_minus_mean() const
     {
-        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
+        return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
     }
     auto variance() const
     {
-        return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
+        return reduce_mean()(arg(0)(any_of(
+            f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
+            f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
+            f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
     }
-    auto layernorm_onnx() const
+    auto sqrt_add_eps(const std::string& name) const
     {
         auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
-        return f("div")(
-            arg(0)(x_minus_mean()),
-            arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance()))))));
+        return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
+    }
+    auto layernorm_onnx() const
+    {
+        auto div_sqrt  = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
+        auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
+        return any(any_of(div_sqrt, mul_rsqrt));
     }
     auto matcher() const { return layernorm_onnx(); }
@@ -31,6 +31,7 @@
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/matcher.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
@@ -340,12 +341,18 @@ struct find_inner_broadcast
                        std::back_inserter(inputs),
                        [](auto i) { return i->inputs().front(); });
         if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape();
+               return i->get_shape() != inputs.front()->get_shape() and
+                      i->get_shape().elements() != 1;
            }))
             return;
-        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
+        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
+            return not i->get_shape().scalar();
+        });
+        if(b_it == broadcasts.end())
+            b_it = broadcasts.begin();
+        auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, (*b_it)->get_operator(), op);
     }
 };
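
A reading of the find_inner_broadcast change above (inferred from the code, not stated in the commit message): the rewrite used to bail out unless every broadcast fed the same pre-broadcast shape; it now also tolerates single-element inputs, uses insert_common_op from the newly included common.hpp to bring mismatched inputs to a common shape before re-inserting the op, and then re-applies the first non-scalar broadcast it can find. The new simplify_inner_broadcast_scalar test at the end of this diff shows the end-to-end effect: add(multibroadcast(x), multibroadcast(y_scalar)) becomes multibroadcast(add(x, multibroadcast(y_scalar))). A thin wrapper around the helper call, as used above; the namespace qualification and header paths other than common.hpp are assumptions:

#include <migraphx/common.hpp>   // provides insert_common_op (include added by this diff)
#include <migraphx/module.hpp>   // header path assumed
#include <vector>

// Mirrors the call in find_inner_broadcast: inserts whatever broadcasts are needed so
// that `inputs` (e.g. a {1, 384} tensor and a {1, 1} tensor) share a common shape, then
// inserts ins->get_operator() on those common-shaped inputs and returns the new instruction.
migraphx::instruction_ref apply_at_common_shape(migraphx::module& m,
                                                migraphx::instruction_ref ins,
                                                const std::vector<migraphx::instruction_ref>& inputs)
{
    return migraphx::insert_common_op(m, ins, ins->get_operator(), inputs);
}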
@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {

 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps   = r.instructions["eps"]->eval().at<float>();
+        float eps  = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();

         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
@@ -100,26 +104,26 @@ struct find_add_layernorm
 {
     auto matcher() const
     {
-        return match::layernorm()(
-            match::var("x")(match::name("add")(match::used_once()).bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins     = r.result;
         auto add_ins = r.instructions["add"];
-        float eps    = 0;
-        if(contains(r.instructions, "eps"))
-            eps = r.instructions["eps"]->eval().at<float>();
+        auto op      = any_cast<layernorm>(ins->get_operator());

-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };

 } // namespace

-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }

 } // namespace gpu
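
The prefuse_ops rework above splits the fusion into two phases: find_layernorm first rewrites any matched pattern into the fused layernorm op, recording epsilon on the op (defaulting to 0 when the pattern has no epsilon constant), dead code elimination then removes the leftover arithmetic, and only afterwards does find_add_layernorm look for a gpu::prelayernorm whose single input is an add used once, forwarding op.epsilon into add_layernorm. Running dead_code_elimination mid-pass is why apply now takes a module_pass_manager rather than a bare module. A sketch of how the pass would presumably be driven, assuming run_passes from pass_manager.hpp (included by this diff) accepts a module plus a list of passes, and that the gpu header path is as guessed:

#include <migraphx/module.hpp>            // header paths assumed
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>   // header path assumed

// The pass manager supplies the module_pass_manager that apply() now requires,
// so prefuse_ops is run through run_passes instead of being applied directly.
void prefuse(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::gpu::prefuse_ops{}, migraphx::dead_code_elimination{}});
}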
@@ -559,6 +559,32 @@ TEST_CASE(simplify_inner_broadcast2)
     EXPECT(m1 == m2);
 }

+TEST_CASE(simplify_inner_broadcast_scalar)
+{
+    auto b = migraphx::op::multibroadcast{{32, 384}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
+        auto yb   = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), x, yb);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;