Commit 187a4769 authored by Paul
Browse files

Fix div by zero issue

parent a62ef598
......@@ -26,8 +26,8 @@ namespace migraphx {
extern "C" {
// Kernel entry point: takes two untyped pointers (input_p, output_p) and
// forwards them through transform_args(...) to a generic lambda that invokes
// layernorm<${axis}>.
// NOTE(review): ${transformers} and ${axis} look like placeholders that are
// substituted before this source is compiled — confirm against the code that
// renders this template.
__global__ void layernorm_kernel(void* input_p, void* output_p)
{
// NOTE(review): this span is a rendered diff without +/- markers. The next two
// lines appear to be the pre-change call (a fixed input/output pair) — confirm
// against the commit.
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
layernorm<${axis}>(input, output);
// NOTE(review): the next two lines appear to be the replacement: the lambda is
// variadic and passes op::id{} as the first argument of layernorm. rotate_last()
// presumably reorders the arguments so the output precedes the inputs — verify
// against its definition.
transform_args(make_tensors(), rotate_last(), ${transformers})(input_p, output_p)([](auto... xs) {
layernorm<${axis}>(op::id{}, xs...);
});
}
......
......@@ -2,18 +2,20 @@
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template <index_int Axis, class Input, class Output>
__device__ void layernorm(Input input, Output output)
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
constexpr auto relements =
get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements();
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
using reduce_output = reduce::with_axis<Input, Axis>;
constexpr auto relements =get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input::type;
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, f)(input) / value_type{relements};
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(input);
};
// mean(x)
auto mean_x = mean(op::id{});
......@@ -23,11 +25,11 @@ __device__ void layernorm(Input input, Output output)
return m * m;
});
r.inner([&](auto& y, auto x) {
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = m * rsqrt(mean_m2 + value_type{1e-12});
})(output, input);
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input, inputs...);
});
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment