"src/include/amd_inline_asm.hpp" did not exist on "46a0aec185f614b413254b49f32d105cd8a7e1fb"
Unverified commit d9578ba6, authored by kahmed10, committed by GitHub

Parameterize epsilon for layernorm kernel (#1367)

This PR allows other values of epsilon to be matched when finding layernorm, and the kernel calculation now uses that epsilon variable instead of a hard-coded 1e-12.
parent 9a70050b
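
For context, layernorm normalizes as y = (x - mean(x)) * rsqrt(variance(x) + epsilon). Before this change, epsilon was hard-coded to 1e-12 on both sides: the matcher only recognized the exact literal `1e-12f`, and the GPU kernel added `value_type{1e-12}` unconditionally, so a model exported with a different epsilon (for example 1e-5, the usual PyTorch/ONNX default) would not match the fused layernorm pattern at all.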
@@ -50,8 +50,8 @@ struct layernorm_matcher
 {
     return f("div")(arg(0)(x_minus_mean()),
-                    arg(1)(skip_broadcasts(f("sqrt")(
-                        arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
+                    arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
+                        f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
 }
 auto matcher() const { return layernorm_onnx(); }
...
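
The matcher change swaps `has_value(1e-12f)`, which only matched that exact literal, for `is_constant().bind("eps")`, which accepts any constant operand of the `add` and records the matched instruction under the name `"eps"` in the match result. The fusion passes later in this commit read the value back like so:

```cpp
// Evaluate the constant instruction bound by is_constant().bind("eps")
// (this is the exact pattern used by find_layernorm below):
auto eps = r.instructions["eps"]->eval().at<float>();
```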
@@ -52,7 +52,7 @@ __global__ void ${kernel}(${params})
 {
     auto idx = make_index();
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
-        ${layernorm}<${axis}>(${post}, xs...);
+        ${layernorm}<${axis}>(${post}, ${eps}, xs...);
     });
 }
@@ -90,6 +90,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
     options.output      = inputs.back();
     options.inputs      = inputs;
     options.kernel_name = v.get("kernel", "layernorm_kernel");
+    auto eps            = v.get("epsilon", 1e-12f);
     auto src = interpolate_string(layernorm_kernel,
                                   {{"kernel", options.kernel_name},
@@ -99,7 +100,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
                                    {"post", v.get("post", std::string{"op::id{}"})},
                                    {"preamble", v.get("preamble", std::string{})},
                                    {"layernorm", v.get("layernorm", std::string{"layernorm"})},
-                                   {"axis", to_string(axis)}});
+                                   {"axis", to_string(axis)},
+                                   {"eps", to_string(eps)}});
     return compile_hip_code_object(src, options);
 }
...
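
Since the kernel source is generated by textual substitution, epsilon travels into the HIP source as a token: the compiler reads it from the op's value with `v.get("epsilon", 1e-12f)` and adds `{"eps", to_string(eps)}` to the map consumed by `interpolate_string`. As a hypothetical expansion with the default `post` and `layernorm` names, `axis = 2`, and `epsilon = 1e-5f`, the templated call would render roughly as:

```cpp
// ${layernorm}<${axis}>(${post}, ${eps}, xs...) after interpolation:
layernorm<2>(op::id{}, 1e-05, xs...);
```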
@@ -43,7 +43,7 @@ template <index_int Axis,
           class Input2,
           class... Inputs>
 __device__ void generic_binary_layernorm(
-    F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+    F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
     using reduce_output = reduce::with_axis<Input1, Axis>;
     reduce::block::run<reduce_output>([&](auto, auto r) {
@@ -55,32 +55,34 @@ __device__ void generic_binary_layernorm(
             return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
         })(input1, input2);
         auto mean_x   = means[0];
         auto mean_x2  = means[1];
         auto variance = mean_x2 - (mean_x * mean_x);
+        value_type eps_val = eps; // implicit conversion for eps
         r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
             auto x = op(x1, x2);
             auto m = x - mean_x;
-            // m * rsqrt(mean(m ^ 2) + 1e-12)
-            y = compute(m * rsqrt(variance + value_type{1e-12}), xs...);
+            // m * rsqrt(mean(m ^ 2) + epsilon)
+            y = compute(m * rsqrt(variance + eps_val), xs...);
         })(output, input1, input2, inputs...);
     });
 }

 template <index_int Axis, class F, class Output, class Input, class... Inputs>
-__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
+__device__ void layernorm(F compute, float eps, Output output, Input input, Inputs... inputs)
 {
     generic_binary_layernorm<Axis>(
-        compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
+        compute, [](auto x, auto) { return x; }, eps, output, input, input, inputs...);
 }

 template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
 __device__ void
-add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+add_layernorm(F compute, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
     generic_binary_layernorm<Axis>(
-        compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
+        compute, [](auto x1, auto x2) { return x1 + x2; }, eps, output, input1, input2, inputs...);
 }
 } // namespace migraphx
...
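
Two details of the device code are worth noting. Variance is computed in a single reduction pass via var = E[x^2] - E[x]^2 (the `means` array carries both moments), and `eps` arrives as a `float` but is converted to the tensor's `value_type` (which may be a reduced-precision type such as `half`) before the `rsqrt`; that is what the new `eps_val` line does. A host-side sketch of the same arithmetic (hypothetical reference code, not part of this commit):

```cpp
#include <cmath>
#include <cstddef>

// One-pass mean/variance as in generic_binary_layernorm:
// accumulate E[x] and E[x^2] together, then var = E[x^2] - E[x]^2.
template <class T>
void layernorm_row(const T* x, T* y, std::size_t n, float eps)
{
    T sum = 0, sum2 = 0;
    for(std::size_t i = 0; i < n; ++i)
    {
        sum += x[i];
        sum2 += x[i] * x[i];
    }
    const T mean_x   = sum / static_cast<T>(n);
    const T mean_x2  = sum2 / static_cast<T>(n);
    const T variance = mean_x2 - mean_x * mean_x;
    const T eps_val  = static_cast<T>(eps); // mirrors the kernel's implicit conversion
    for(std::size_t i = 0; i < n; ++i)
        y[i] = (x[i] - mean_x) / std::sqrt(variance + eps_val);
}
```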
@@ -35,6 +35,12 @@ namespace {
 template <class Derived, std::size_t N>
 struct layernorm_base
 {
+    float epsilon = 1e-12f;
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.epsilon, "epsilon"));
+    }
     shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
     {
         std::size_t nargs = 1;
@@ -62,6 +68,7 @@ struct layernorm_base
 struct layernorm : layernorm_base<layernorm, 0>
 {
     std::string name() const { return "gpu::prelayernorm"; }
 };
 MIGRAPHX_REGISTER_OP(layernorm);
@@ -80,8 +87,9 @@ struct find_layernorm
 {
     auto ins   = r.result;
     auto x_ins = r.instructions["x"];
-    m.replace_instruction(ins, layernorm{}, x_ins);
+    auto eps   = r.instructions["eps"]->eval().at<float>();
+    m.replace_instruction(ins, layernorm{eps}, x_ins);
 }
 };
@@ -96,8 +104,9 @@ struct find_add_layernorm
 {
     auto ins     = r.result;
     auto add_ins = r.instructions["add"];
-    m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
+    auto eps     = r.instructions["eps"]->eval().at<float>();
+    m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
 }
 };
 } // namespace
...
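
Making `epsilon` a reflected member of `layernorm_base` is what ties the pieces together: reflection puts the field into the operator's serialized value, so the JIT compiler's `v.get("epsilon", 1e-12f)` above observes whatever the matcher captured, and aggregate initialization (`layernorm{eps}`, `add_layernorm{eps}`) carries it from the match result into the op. Note that `eval().at<float>()` requires the bound instruction to be evaluable at compile time, which `is_constant()` guarantees.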
@@ -29,14 +29,16 @@
 #include <migraphx/op/reduce_mean.hpp>

-migraphx::instruction_ref
-add_layernorm(migraphx::module& m, migraphx::instruction_ref x, std::vector<size_t> dims)
+migraphx::instruction_ref add_layernorm(migraphx::module& m,
+                                        migraphx::instruction_ref x,
+                                        std::vector<size_t> dims,
+                                        float eps = 1e-12f)
 {
     auto scale =
         m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
     auto bias =
         m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
-    auto epsilon  = m.add_literal(1e-12f);
+    auto epsilon  = m.add_literal(eps);
     auto exponent = m.add_literal(2.0f);
     auto mean     = m.add_instruction(migraphx::op::reduce_mean({2}), x);
@@ -88,6 +90,19 @@ struct test_layernorm2 : verify_program<test_layernorm2>
     }
 };

+struct test_layernorm_eps : verify_program<test_layernorm_eps>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm                 = p.get_main_module();
+        std::vector<size_t> dims = {1, 2, 5};
+        auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
+        add_layernorm(*mm, x, dims, 1e-5f);
+        return p;
+    }
+};
+
 struct test_layernorm_triadd : verify_program<test_layernorm_triadd>
 {
     migraphx::program create_program() const
...
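
The new `test_layernorm_eps` exercises the path end-to-end with a non-default epsilon of `1e-5f`: for the verification to pass, the matcher has to bind that constant, the prefuse pass has to store it in the op, and the JIT kernel has to use it in place of `1e-12`. The existing tests are unchanged and keep the old value through the helper's `eps = 1e-12f` default argument.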