Commit c7590278 authored by Jiashi Li
Browse files

Fix accuracy issue in sum_OdO kernel

parent ef5b1a69
@@ -140,7 +140,7 @@ struct FmhaKernelBwdSumOdO {
     *reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
     for (int v = 0; v < kElementsPerLoad; v++) {
-      acc += value_O[v] * value_dO[v];
+      acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]);
     }
   }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment