Commit f9d22b02 authored by rocking's avatar rocking
Browse files

Refine type

parent 94d5f723
...@@ -206,9 +206,9 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -206,9 +206,9 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float out_f32 = 0; AccDataType out_acc = 0;
layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n)); layerNormInst(out_acc, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
out_m_n(m, n) = static_cast<DDataType>(out_f32); out_m_n(m, n) = static_cast<DDataType>(out_acc);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment