Unverified Commit 3c67e66f authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Modify layernorm to allow higher overflow limit for lower precision (#1534)

parent a81036cf
......@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm(
{
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>;
block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_type<value_type>{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto mean_x = means[0];
......
......@@ -24,6 +24,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -34,12 +35,12 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
std::vector<size_t> dims,
float eps = 1e-12f)
{
auto scale =
m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto bias =
m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = m.add_literal(eps);
auto exponent = m.add_literal(2.0f);
auto mgx_type = x->get_shape().type();
auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}});
auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {dims.back()}});
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}});
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast =
......@@ -90,6 +91,19 @@ struct test_layernorm2 : verify_program<test_layernorm2>
}
};
struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 24, 64};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_eps : verify_program<test_layernorm_eps>
{
migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment