Commit 3d091db2 authored by rocking's avatar rocking
Browse files

Use AccDataType

parent 6c0636c5
......@@ -181,8 +181,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m)
{
float mean_acc = reduceSumOpInst.GetIdentityValue();
float square_mean_acc = reduceSumOpInst.GetIdentityValue();
AccDataType mean_acc = reduceSumOpInst.GetIdentityValue();
AccDataType square_mean_acc = reduceSumOpInst.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment