Commit d94c54f0 authored by Paul's avatar Paul
Browse files

Add layernorm kernel

parent 6d8e4c53
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void layernorm_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
layernorm<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const { return {"layernorm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "layernorm_kernel";
auto src = interpolate_string(
layernorm_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
template <index_int Axis, class Input, class Output>
__device__ void layernorm(Input input, Output output)
{
constexpr auto relements = get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements();
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
using value_type = typename Input::type;
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, f)(input) / value_type{relements};
};
// mean(x)
auto mean_x = mean(op::id{});
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x) {
auto m = x - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x) {
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = m * rsqrt(mean_m2 + value_type{1e-12});
})(output, input);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
......@@ -6,6 +6,29 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct layernorm
{
std::string name() const { return "gpu::prelayernorm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
namespace {
struct find_layernorm
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment