Commit d8970ea0 authored by Jing Zhang
Browse files

Add double register buffer to GEMM

parent f05b210a
......@@ -145,6 +145,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
//FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 2];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
......@@ -173,6 +174,39 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
#if 0
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
int b_reg_0 = (k % 2) * 2;
int b_reg_1 = ((k - 1) % 2) * 2;
reg_b[b_reg_0 + 0] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
reg_b[b_reg_0 + 1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(
reg_a[1], reg_b[b_reg_1 + 0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(
reg_a[1], reg_b[b_reg_1 + 1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
}
outerProduct4x4(reg_a[0], reg_b[2], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[3], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
outerProduct4x4(reg_a[1], reg_b[2], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[3], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
......@@ -191,6 +225,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#endif
}
#endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment