Commit d78877a7 authored by rocking's avatar rocking
Browse files

Extract divisor out of the loop in reference layernorm

parent 62bc2102
...@@ -90,10 +90,11 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -90,10 +90,11 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto divisor = 1 / sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_); auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val); arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment