Commit 78a1dc1e authored by Paul's avatar Paul
Browse files

Add a quick groupnorm op

parent bf0a7d92
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const groupnorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/groupnorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void groupnorm_kernel(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
groupnorm(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct groupnorm_compiler : compiler<groupnorm_compiler>
{
std::vector<std::string> names() const
{
return {"groupnorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(ctx, faxis, inputs);
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements();
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "groupnorm_kernel";
auto src = interpolate_string(groupnorm_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["groupnorm"] = "groupnorm";
v["kernel"] = "groupnorm_kernel";
if(op.name() == "gpu::preadd_groupnorm")
{
v["groupnorm"] = "add_groupnorm";
v["kernel"] = "add_groupnorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_groupnorm");
v["post"] = "MIGRAPHX_LIFT(post_groupnorm)";
v["kernel"] =
v["groupnorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
#define GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template<class Output, class T>
__device__ void groupnorm(Output out, T x0) {
reduce::block::run<Output>([&](auto out_idx, auto r) {
constexpr auto relements = r.template elements<T>();
auto z1 = r.reduce(op::sum{}, 0, op::mean<relements>{})(x0);
auto z4 = r.reduce(op::sum{}, 0, [&](auto x) {
auto diff = x - z1;
return (diff * diff) / vec_type<decltype(diff)>{relements};
})(x0);
r.outer([&] {
out[out_idx] = migraphx::rsqrt(z4 + 1e-12);
});
});
}
} // namespace migraphx
#endif // GUARD_AMDMIGRAPHX_GROUP_NORM_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment