Commit 5b8c054e authored by Paul's avatar Paul
Browse files

Add add_layernorm fusion

parent 610f0417
...@@ -825,13 +825,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r) ...@@ -825,13 +825,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
inline auto precompile_name(std::string s) // NOLINT template<class... Strings>
inline auto precompile_name(Strings... names) // NOLINT
{ {
return match::make_basic_pred_matcher([=](instruction_ref ins) { return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op") if(ins->name() != "gpu::precompile_op")
return false; return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op")); auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s); return (contains({names...}, op.name()));
}); });
} }
...@@ -1052,7 +1053,7 @@ struct find_layernorm_pointwise ...@@ -1052,7 +1053,7 @@ struct find_layernorm_pointwise
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")( return precompile_name("pointwise")(
match::arg(0)(precompile_name("gpu::prelayernorm").bind("layernorm"))); match::arg(0)(precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -1061,8 +1062,12 @@ struct find_layernorm_pointwise ...@@ -1061,8 +1062,12 @@ struct find_layernorm_pointwise
auto layernorm = r.instructions["layernorm"]; auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto inputs = ins->inputs(); if (not layernorm->module_inputs().empty())
inputs.front() = layernorm->inputs().front(); return;
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin()+1, ins->inputs().end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm}); m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
} }
......
...@@ -31,7 +31,7 @@ __global__ void ${kernel}(${params}) ...@@ -31,7 +31,7 @@ __global__ void ${kernel}(${params})
{ {
auto idx = make_index(); auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
layernorm<${axis}>(${post}, xs...); ${layernorm}<${axis}>(${post}, xs...);
}); });
} }
...@@ -43,7 +43,7 @@ __global__ void ${kernel}(${params}) ...@@ -43,7 +43,7 @@ __global__ void ${kernel}(${params})
struct layernorm_compiler : compiler<layernorm_compiler> struct layernorm_compiler : compiler<layernorm_compiler>
{ {
std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm"}; } std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
...@@ -74,6 +74,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -74,6 +74,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}}); {"axis", to_string(axis)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
...@@ -82,12 +83,19 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -82,12 +83,19 @@ struct layernorm_compiler : compiler<layernorm_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if (op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm"); v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)"; v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] = "layernorm_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
} }
return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
} }
......
...@@ -6,34 +6,47 @@ ...@@ -6,34 +6,47 @@
namespace migraphx { namespace migraphx {
template <index_int Axis, class F, class Output, class Input, class... Inputs> template <index_int Axis, class F, class BinOp, class Output, class Input1, class Input2, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs) __device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using reduce_output = reduce::with_axis<Input, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
constexpr auto relements = constexpr auto relements =
get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements(); get_shape_c<Input1>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0); MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input::type; using value_type = typename Input1::type;
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })( return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return f(x1, x2) / value_type{relements}; })(
input); input1, input2);
}; };
// mean(x) // mean(x)
auto mean_x = mean(op::id{}); auto mean_x = mean(op);
// mean(m ^ 2) // mean(m ^ 2)
auto mean_m2 = mean([&](auto x) { auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = x - mean_x; auto m = op(x1, x2) - mean_x;
return m * m; return m * m;
}); });
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = x - mean_x; auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...); y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input, inputs...); })(output, input1, input2, inputs...);
}); });
} }
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP #endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
...@@ -30,10 +30,10 @@ namespace migraphx { ...@@ -30,10 +30,10 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace { namespace {
struct layernorm
{
std::string name() const { return "gpu::prelayernorm"; }
template<class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = 1;
...@@ -42,7 +42,7 @@ struct layernorm ...@@ -42,7 +42,7 @@ struct layernorm
auto* pm = mods.front(); auto* pm = mods.front();
nargs = pm->get_parameter_names().size(); nargs = pm->get_parameter_names().size();
} }
check_shapes{inputs, *this}.has(nargs); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs+N);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.scalar()) if(s.scalar())
{ {
...@@ -58,8 +58,19 @@ struct layernorm ...@@ -58,8 +58,19 @@ struct layernorm
} }
} }
}; };
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm); MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
...@@ -73,6 +84,19 @@ struct find_layernorm ...@@ -73,6 +84,19 @@ struct find_layernorm
} }
}; };
struct find_add_layernorm
{
auto matcher() const { return match::layernorm()(match::var("x")(match::name("add").bind("add"))); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto add_ins = r.instructions["add"];
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
}
};
struct find_gpulayernorm struct find_gpulayernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
...@@ -134,7 +158,7 @@ struct find_gputriaddlayernorm ...@@ -134,7 +158,7 @@ struct find_gputriaddlayernorm
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_layernorm{}); match::find_matches(m, find_add_layernorm{}, find_layernorm{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{}); // match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment