Commit 98df59c6 authored by letaoqin's avatar letaoqin
Browse files

bias data type convert

parent 64c9f790
......@@ -322,12 +322,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
});
// softmax
......
......@@ -396,12 +396,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
});
// softmax
......
......@@ -1307,8 +1307,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_thread_buf);
// acc add bias
static_for<0, m0 * n0 * n2 * n4, 1>{}(
[&](auto i) { acc_thread_buf(i) += d0_thread_buf[i]; });
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
acc_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
d0_threadwise_copy.MoveSrcSliceWindow(
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment