Commit 85247d4d authored by Paul's avatar Paul
Browse files

Format

parent d94c54f0
...@@ -8,16 +8,17 @@ namespace migraphx { ...@@ -8,16 +8,17 @@ namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void layernorm(Input input, Output output) __device__ void layernorm(Input input, Output output)
{ {
constexpr auto relements = get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements(); constexpr auto relements =
get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements();
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
using value_type = typename Input::type; using value_type = typename Input::type;
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, f)(input) / value_type{relements}; return r.reduce(op::sum{}, 0, f)(input) / value_type{relements};
}; };
// mean(x) // mean(x)
auto mean_x = mean(op::id{}); auto mean_x = mean(op::id{});
// mean(m ^ 2) // mean(m ^ 2)
auto mean_m2 = mean([&](auto x) { auto mean_m2 = mean([&](auto x) {
auto m = x - mean_x; auto m = x - mean_x;
return m * m; return m * m;
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment