Commit 1af49c6f authored by Paul's avatar Paul
Browse files

Format

parent 187a4769
......@@ -10,12 +10,14 @@ template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input, Axis>;
constexpr auto relements =get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements();
constexpr auto relements =
get_shape_c<Input>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input::type;
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(input);
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(
input);
};
// mean(x)
auto mean_x = mean(op::id{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment