Commit 2ae93464 authored by Paul's avatar Paul
Browse files

Format

parent 78a1dc1e
...@@ -61,10 +61,7 @@ __global__ void groupnorm_kernel(${params}) ...@@ -61,10 +61,7 @@ __global__ void groupnorm_kernel(${params})
struct groupnorm_compiler : compiler<groupnorm_compiler> struct groupnorm_compiler : compiler<groupnorm_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const { return {"groupnorm"}; }
{
return {"groupnorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
......
...@@ -8,18 +8,17 @@ ...@@ -8,18 +8,17 @@
namespace migraphx { namespace migraphx {
template<class Output, class T> template <class Output, class T>
__device__ void groupnorm(Output out, T x0) { __device__ void groupnorm(Output out, T x0)
{
reduce::block::run<Output>([&](auto out_idx, auto r) { reduce::block::run<Output>([&](auto out_idx, auto r) {
constexpr auto relements = r.template elements<T>(); constexpr auto relements = r.template elements<T>();
auto z1 = r.reduce(op::sum{}, 0, op::mean<relements>{})(x0); auto z1 = r.reduce(op::sum{}, 0, op::mean<relements>{})(x0);
auto z4 = r.reduce(op::sum{}, 0, [&](auto x) { auto z4 = r.reduce(op::sum{}, 0, [&](auto x) {
auto diff = x - z1; auto diff = x - z1;
return (diff * diff) / vec_type<decltype(diff)>{relements}; return (diff * diff) / vec_type<decltype(diff)>{relements};
})(x0); })(x0);
r.outer([&] { r.outer([&] { out[out_idx] = migraphx::rsqrt(z4 + 1e-12); });
out[out_idx] = migraphx::rsqrt(z4 + 1e-12);
});
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment