Unverified Commit 3c67e66f authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Modify layernorm to allow higher overflow limit for lower precision (#1534)

parent a81036cf
...@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm( ...@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm(
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>; using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_type<value_type>{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input); })(input);
auto mean_x = means[0]; auto mean_x = means[0];
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -34,12 +35,12 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m, ...@@ -34,12 +35,12 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
std::vector<size_t> dims, std::vector<size_t> dims,
float eps = 1e-12f) float eps = 1e-12f)
{ {
auto scale = auto mgx_type = x->get_shape().type();
m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}}); auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}});
auto bias = auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {dims.back()}});
m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = m.add_literal(eps); auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto exponent = m.add_literal(2.0f); auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}});
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x); auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = auto mean_mbcast =
...@@ -90,6 +91,19 @@ struct test_layernorm2 : verify_program<test_layernorm2> ...@@ -90,6 +91,19 @@ struct test_layernorm2 : verify_program<test_layernorm2>
} }
}; };
// Verify-program case that feeds a half-precision input through add_layernorm.
struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
{
    // Builds a program whose main module takes a half_type parameter "x" with
    // dimensions {1, 24, 64} and passes it to add_layernorm, leaving eps at
    // add_layernorm's default.
    migraphx::program create_program() const
    {
        std::vector<size_t> input_dims = {1, 24, 64};
        migraphx::shape input_shape{migraphx::shape::half_type, input_dims};

        migraphx::program prog;
        auto* main_module = prog.get_main_module();
        auto input        = main_module->add_parameter("x", input_shape);
        add_layernorm(*main_module, input, input_dims);
        return prog;
    }
};
struct test_layernorm_eps : verify_program<test_layernorm_eps> struct test_layernorm_eps : verify_program<test_layernorm_eps>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment