Commit 9080a1df authored by Paul
Browse files

Handle zero epsilon

parent 100551f7
......@@ -48,10 +48,10 @@ struct layernorm_matcher
auto layernorm_onnx() const
{
auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(match::any_of(add_eps, variance()))))));
}
auto matcher() const { return layernorm_onnx(); }
......
......@@ -104,7 +104,9 @@ struct find_add_layernorm
{
auto ins = r.result;
auto add_ins = r.instructions["add"];
auto eps = r.instructions["eps"]->eval().at<float>();
float eps = 0;
if (contains(r.instructions, "eps"))
eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment