Commit 95665ff2 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Take in epsilon as a parameter

parent 928da1f0
......@@ -51,8 +51,7 @@ struct layernorm_matcher
return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(f("add")(
either_arg(0, 1)(variance(), is_constant()))))))); // 71.7596/sec
// arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f)))))))); // 70.8157/sec
either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
}
auto matcher() const { return layernorm_onnx(); }
......
......@@ -54,7 +54,7 @@ __global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
${layernorm}<${axis}>(${post}, ${eps}, xs...);
});
}
......@@ -92,6 +92,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel");
auto eps = v.get("epsilon", 1e-12f);
auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name},
......@@ -101,7 +102,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
{"axis", to_string(axis)},
{"eps", to_string(eps)}});
return compile_hip_code_object(src, options);
}
......
......@@ -43,7 +43,7 @@ template <index_int Axis,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
......@@ -58,29 +58,30 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
auto eps_val = static_cast<value_type>(eps);
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1.00136e-05)
y = compute(m * rsqrt(variance + value_type{1.00136e-05}), xs...);
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...);
});
}
// Single-input layernorm over dimension `Axis`.
//
// Forwards to generic_binary_layernorm with an identity combine op, passing
// the same tensor twice so the binary form reduces to plain layernorm.
// `eps` is the epsilon added to the variance before rsqrt for numerical
// stability; `compute` is applied to each normalized element (with any
// trailing `inputs...`, e.g. scale/bias tensors) to produce `output`.
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, float eps, Output output, Input input, Inputs... inputs)
{
    generic_binary_layernorm<Axis>(
        compute, [](auto x, auto) { return x; }, eps, output, input, input, inputs...);
}
// Fused add + layernorm over dimension `Axis`.
//
// Computes x = input1 + input2 elementwise (via the `+` combine op) and then
// normalizes x, delegating to generic_binary_layernorm. `eps` is the epsilon
// added to the variance before rsqrt for numerical stability; `compute` is
// applied to each normalized element (with any trailing `inputs...`) to
// produce `output`.
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
    generic_binary_layernorm<Axis>(
        compute, [](auto x1, auto x2) { return x1 + x2; }, eps, output, input1, input2, inputs...);
}
} // namespace migraphx
......
......@@ -34,6 +34,13 @@ namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
float epsilon = 1e-12f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.epsilon, "epsilon"));
}
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
......@@ -61,6 +68,7 @@ struct layernorm_base
// Standalone (non-fused) layernorm op; epsilon and shape checking come from
// layernorm_base. Registered under the name below via MIGRAPHX_REGISTER_OP.
struct layernorm : layernorm_base<layernorm, 0>
{
    // Identifier reported to the op registry.
    std::string name() const
    {
        return "gpu::prelayernorm";
    }
};
MIGRAPHX_REGISTER_OP(layernorm);
......@@ -79,8 +87,9 @@ struct find_layernorm
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, layernorm{}, x_ins);
m.replace_instruction(ins, layernorm{eps}, x_ins);
}
};
......@@ -95,8 +104,9 @@ struct find_add_layernorm
{
auto ins = r.result;
auto add_ins = r.instructions["add"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
}
};
} // namespace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment