Commit 0de4286a authored by Jing Zhang's avatar Jing Zhang
Browse files

increase depth of pipeline

parent 6fef303e
...@@ -384,9 +384,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -384,9 +384,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
Float4* reg_c = (Float4*)(p_c_thread); Float4* reg_c = (Float4*)(p_c_thread);
void* a_loc = (void *)(p_a_block + mMyThreadOffsetA); void* a_loc = (void *)(p_a_block + mMyThreadOffsetA);
void* b_loc = (void *)(p_b_block + mMyThreadOffsetB); void* b_loc = (void *)(p_b_block + mMyThreadOffsetB);
#pragma unroll
// loop over k // loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) int k_chunk = 2;
#pragma unroll
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
{ {
#if 0 #if 0
...@@ -402,15 +403,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -402,15 +403,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else #else
ds_read_b128(reg_a[0], a_loc, k_begin * 512); int k = k_begin;
ds_read_b128(reg_b[0], b_loc, k_begin * 256); ds_read_b128(reg_a[0], a_loc, k * 512);
ds_read_b128(reg_b[1], b_loc, 128 + k_begin * 256); ds_read_b128(reg_b[0], b_loc, k * 256);
ds_read_b128(reg_a[1], a_loc, 256 + k_begin * 512); ds_read_b128(reg_b[1], b_loc, 128 + k * 256);
ds_read_b128(reg_a[1], a_loc, 256 + k * 512);
lgkmcnt(2); lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1); lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0); lgkmcnt(0);
for(int i = 0; i < k_chunk - 1; i++)
{
k = k + 1;
ds_read_b128(reg_a[0], a_loc, k * 512);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_loc, k * 256);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_loc, 128 + k * 256);
ds_read_b128(reg_a[1], a_loc, 256 + k * 512);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#endif #endif
......
...@@ -73,20 +73,16 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -73,20 +73,16 @@ __device__ void threadwise_gemm(MatrixA,
for(index_t k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
#if 1 #if 1
for(index_t i = 0; i < M; i+=4) for(index_t i = 0; i < M; i++)
{ {
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const Float4 *a_vec = (const Float4 *)&p_a_thread[aindex];
for(index_t j = 0; j < N; j+=4) for(index_t j = 0; j < N; j++)
{ {
const index_t bindex = b_mtx.Get1dIndex(k, j); const index_t bindex = b_mtx.Get1dIndex(k, j);
const index_t cindex = c_mtx.Get1dIndex(i, j); const index_t cindex = c_mtx.Get1dIndex(i, j);
const Float4 *b_vec = (const Float4 *)&p_b_thread[bindex]; p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
Float4 *c_vec = (Float4 *)&p_c_thread[cindex];
outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]);
} }
} }
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment