Commit 5b8c054e authored by Paul

Add add_layernorm fusion

parent 610f0417
@@ -825,13 +825,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
}
inline auto precompile_name(std::string s) // NOLINT
template<class... Strings>
inline auto precompile_name(Strings... names) // NOLINT
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op")
return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s);
return (contains({names...}, op.name()));
});
}
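A quick usage sketch of the widened matcher (the variable name below is illustrative and not part of the diff):
// Matches a gpu::precompile_op whose wrapped operation has any of the listed
// names; the old overload accepted exactly one name as a std::string.
auto any_prelayernorm = precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm");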
@@ -1052,7 +1053,7 @@ struct find_layernorm_pointwise
auto matcher() const
{
return precompile_name("pointwise")(
match::arg(0)(precompile_name("gpu::prelayernorm").bind("layernorm")));
match::arg(0)(precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}
void apply(module& m, const match::matcher_result& r) const
@@ -1061,8 +1062,12 @@ struct find_layernorm_pointwise
auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
auto inputs = ins->inputs();
inputs.front() = layernorm->inputs().front();
if(not layernorm->module_inputs().empty())
return;
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
}
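In IR terms, the updated apply now splices the pointwise module onto the layernorm instruction roughly as sketched below; the value names and the alloc* output allocations are hypothetical, and the trailing input dropped by pop_back is assumed to be the layernorm's output allocation:
// before:
//   ln = gpu::precompile_op[gpu::preadd_layernorm](x, y, alloc0)
//   pw = gpu::precompile_op[pointwise](ln, z, alloc1)                  { module: pm }
// after (skipped entirely when ln already carries a fused pointwise module):
//   fused = gpu::precompile_op[gpu::preadd_layernorm](x, y, z, alloc1) { module: pm }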
@@ -31,7 +31,7 @@ __global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
layernorm<${axis}>(${post}, xs...);
${layernorm}<${axis}>(${post}, xs...);
});
}
@@ -43,7 +43,7 @@ __global__ void ${kernel}(${params})
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm"}; }
std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
@@ -74,6 +74,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
@@ -82,12 +83,19 @@ struct layernorm_compiler : compiler<layernorm_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if(op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] = "layernorm_" + generate_name_from_ops(*pm) + "_kernel";
v["kernel"] = v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
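Putting the substitutions together, the source generated for a gpu::preadd_layernorm fused with a pointwise module looks roughly like the sketch below; the axis value and the "mul" in the kernel name are made up, ${params}/${transformers}/${args} keep their usual (unchanged) substitutions and are left as placeholders, and post_layernorm is assumed to be defined by the generate_pointwise preamble emitted earlier in the same source:
__global__ void add_layernorm_mul_kernel(${params})
{
    auto idx = make_index();
    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
        add_layernorm<2>(MIGRAPHX_LIFT(post_layernorm), xs...);
    });
}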
@@ -6,34 +6,47 @@
namespace migraphx {
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
template <index_int Axis, class F, class BinOp, class Output, class Input1, class Input2, class... Inputs>
__device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input, Axis>;
using reduce_output = reduce::with_axis<Input1, Axis>;
constexpr auto relements =
get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements();
get_shape_c<Input1>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input::type;
using value_type = typename Input1::type;
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(
input);
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return f(x1, x2) / value_type{relements}; })(
input1, input2);
};
// mean(x)
auto mean_x = mean(op::id{});
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x) {
auto m = x - mean_x;
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input, inputs...);
})(output, input1, input2, inputs...);
});
}
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
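As a host-side reference for what the device code computes over one reduction slice, a minimal sketch (assuming contiguous float data, ignoring strides/broadcasting and the extra pointwise inputs that compute would consume on the GPU):
#include <cmath>
#include <cstddef>
#include <vector>

// Reference add_layernorm: normalize (x1 + x2) by its mean and biased variance,
// using the same 1e-12 epsilon as the kernel.
std::vector<float> add_layernorm_ref(const std::vector<float>& x1, const std::vector<float>& x2)
{
    const std::size_t n = x1.size();
    const float rn      = 1.0f / static_cast<float>(n);
    float mean = 0.0f;
    for(std::size_t i = 0; i < n; i++)
        mean += (x1[i] + x2[i]) * rn; // mean(x1 + x2)
    float var = 0.0f;
    for(std::size_t i = 0; i < n; i++)
    {
        const float m = (x1[i] + x2[i]) - mean;
        var += m * m * rn; // mean(m^2)
    }
    std::vector<float> y(n);
    for(std::size_t i = 0; i < n; i++)
        y[i] = ((x1[i] + x2[i]) - mean) / std::sqrt(var + 1e-12f);
    return y;
}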
@@ -30,10 +30,10 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct layernorm
{
std::string name() const { return "gpu::prelayernorm"; }
template<class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
@@ -42,7 +42,7 @@ struct layernorm
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, *this}.has(nargs);
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
@@ -58,8 +58,19 @@
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
@@ -73,6 +84,19 @@ struct find_layernorm
}
};
struct find_add_layernorm
{
auto matcher() const { return match::layernorm()(match::var("x")(match::name("add").bind("add"))); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto add_ins = r.instructions["add"];
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
}
};
struct find_gpulayernorm
{
auto matcher() const { return match::layernorm(); }
@@ -134,7 +158,7 @@ struct find_gputriaddlayernorm
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_layernorm{});
match::find_matches(m, find_add_layernorm{}, find_layernorm{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
}
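In IR terms, find_add_layernorm performs the rewrite sketched below (value names are hypothetical); it is listed ahead of find_layernorm, presumably so the fused form wins whenever the normalized value is produced by an add:
// before: the usual layernorm pattern computed on the result of an add
//   sum = add(x, y)
//   out = <layernorm subgraph over sum>   // matched by match::layernorm()
// after: the whole pattern collapses into the fused op, fed by the add's operands
//   out = gpu::preadd_layernorm(x, y)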