Commit f3aa2c67 authored by Paul's avatar Paul
Browse files

Add layernorm post fusion

parent b973563b
...@@ -1047,6 +1047,26 @@ struct find_contiguous_pointwise ...@@ -1047,6 +1047,26 @@ struct find_contiguous_pointwise
} }
}; };
// Fuses a trailing pointwise op into a preceding gpu::prelayernorm so that the
// pointwise computation runs as a post-op inside the fused layernorm kernel
// (one kernel launch instead of two).
struct find_layernorm_pointwise
{
// Match a "pointwise" instruction whose argument 0 is produced by a
// gpu::prelayernorm instruction; the layernorm is bound as "layernorm".
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(precompile_name("gpu::prelayernorm").bind("layernorm")));
}
// Replace the matched pointwise with the layernorm operator itself,
// attaching the pointwise's sub-module as the fused kernel's post-op.
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result; // the matched pointwise instruction
auto layernorm = r.instructions["layernorm"]; // the upstream prelayernorm
// The pointwise instruction carries its elementwise computation as a
// module input; it becomes the post-op module of the fused instruction.
auto* pm = ins->module_inputs().front();
auto inputs = ins->inputs();
// Re-wire arg 0: the fused instruction consumes the layernorm's own input
// instead of the layernorm's output, since it now computes both steps.
// NOTE(review): if the layernorm result has other readers, the original
// prelayernorm instruction remains live and is computed separately --
// presumably later dead-code elimination removes the single-use case;
// TODO confirm.
inputs.front() = layernorm->inputs().front();
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
}
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
...@@ -1069,6 +1089,7 @@ void fuse_ops::apply(module& m) const ...@@ -1069,6 +1089,7 @@ void fuse_ops::apply(module& m) const
match::find_matches(m, match::find_matches(m,
find_triadd_layernorm{}, find_triadd_layernorm{},
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
......
...@@ -19,15 +19,19 @@ static const char* const layernorm_kernel = R"__migraphx__( ...@@ -19,15 +19,19 @@ static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp> #include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp> #include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp> #include <args.hpp>
namespace migraphx { namespace migraphx {
${preamble}
extern "C" { extern "C" {
__global__ void layernorm_kernel(void* input_p, void* output_p) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), rotate_last(), ${transformers})(input_p, output_p)([](auto... xs) { auto idx = make_index();
layernorm<${axis}>(op::id{}, xs...); transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
layernorm<${axis}>(${post}, xs...);
}); });
} }
...@@ -52,6 +56,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -52,6 +56,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(faxis, inputs);
} }
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements; auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
...@@ -60,18 +65,32 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -60,18 +65,32 @@ struct layernorm_compiler : compiler<layernorm_compiler>
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "layernorm_kernel"; options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string( auto src = interpolate_string(
layernorm_kernel, layernorm_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}}); {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); auto v = op.to_value();
if (not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] = "layernorm_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
} }
}; };
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)) [](auto&&... private_lift_xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(private_lift_xs)>(private_lift_xs)...))
namespace migraphx { namespace migraphx {
......
...@@ -34,9 +34,14 @@ struct layernorm ...@@ -34,9 +34,14 @@ struct layernorm
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
check_shapes{inputs, *this}.has(1); std::size_t nargs = 1;
if (not mods.empty()) {
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, *this}.has(nargs);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.scalar()) if(s.scalar())
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment