Commit f550f814 authored by Umang Yadav's avatar Umang Yadav
Browse files

add layernorm, remove constexpr for 1/r

parent 13403ab2
...@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm( ...@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm(
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); using vec_value_type = vec_type<value_type>;
constexpr auto relements_r = vec_type<value_type>{1.0 / relements}; constexpr auto relements = r.template elements<Input1>();
auto relements_rsqrt = sqrt(relements_r); auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { auto means = r.reduce(op::sum{},
auto x_out = x * relements_r; make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
// dividing x by sqrt(relements) before squaring allows computing higher values [&](auto x) {
// before overflow in low precision auto x_out = x * relements_r;
auto x2_sqrt = x * relements_rsqrt; // dividing x by sqrt(relements) before squaring allows computing
return make_array(x_out, x2_sqrt * x2_sqrt); // higher values before overflow in low precision
})(input); auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = value_type{eps};
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
......
...@@ -117,17 +117,18 @@ struct test_layernorm_fp16 : verify_program<test_layernorm_fp16> ...@@ -117,17 +117,18 @@ struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
} }
}; };
// struct test_layernorm_fp8 : verify_program<test_layernorm_fp8> struct test_layernorm_fp8 : verify_program<test_layernorm_fp8>
// { {
// migraphx::program create_program() const migraphx::program create_program() const
// { {
// migraphx::program p; migraphx::program p;
// auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
// std::vector<size_t> dims = {1, 24, 64}; std::vector<size_t> dims = {1, 24, 64};
// auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, dims});
// dims}); add_layernorm(*mm, x, dims); return p; add_layernorm(*mm, x, dims);
// } return p;
// }; }
};
struct test_layernorm_eps : verify_program<test_layernorm_eps> struct test_layernorm_eps : verify_program<test_layernorm_eps>
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment