Commit 187a4769 authored by Paul
Browse files

Fix div by zero issue

parent a62ef598
...@@ -26,8 +26,8 @@ namespace migraphx { ...@@ -26,8 +26,8 @@ namespace migraphx {
extern "C" { extern "C" {
__global__ void layernorm_kernel(void* input_p, void* output_p) __global__ void layernorm_kernel(void* input_p, void* output_p)
{ {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) { transform_args(make_tensors(), rotate_last(), ${transformers})(input_p, output_p)([](auto... xs) {
layernorm<${axis}>(input, output); layernorm<${axis}>(op::id{}, xs...);
}); });
} }
......
...@@ -2,18 +2,20 @@ ...@@ -2,18 +2,20 @@
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp> #include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx { namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(Input input, Output output) __device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{ {
constexpr auto relements = using reduce_output = reduce::with_axis<Input, Axis>;
get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements(); constexpr auto relements =get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements();
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input::type; using value_type = typename Input::type;
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, f)(input) / value_type{relements}; return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(input);
}; };
// mean(x) // mean(x)
auto mean_x = mean(op::id{}); auto mean_x = mean(op::id{});
...@@ -23,11 +25,11 @@ __device__ void layernorm(Input input, Output output) ...@@ -23,11 +25,11 @@ __device__ void layernorm(Input input, Output output)
return m * m; return m * m;
}); });
r.inner([&](auto& y, auto x) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = m * rsqrt(mean_m2 + value_type{1e-12}); y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input); })(output, input, inputs...);
}); });
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment