Commit 85247d4d authored by Paul's avatar Paul
Browse files

Format

parent d94c54f0
......@@ -8,16 +8,17 @@ namespace migraphx {
template <index_int Axis, class Input, class Output>
__device__ void layernorm(Input input, Output output)
{
constexpr auto relements = get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements();
constexpr auto relements =
get_shape_c<reduce::with_axis<Input, Axis>>{}.elements() / get_shape_c<Input>{}.elements();
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
using value_type = typename Input::type;
auto mean = [&](auto f) {
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, f)(input) / value_type{relements};
};
// mean(x)
auto mean_x = mean(op::id{});
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x) {
auto mean_m2 = mean([&](auto x) {
auto m = x - mean_x;
return m * m;
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment