Commit 5696c81f authored by Chao Liu's avatar Chao Liu
Browse files

simplify blockwise batched GEMM

parent 71434918
......@@ -211,52 +211,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
// copy B-sub to form B
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// loop over batch
#pragma unroll
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC);
// read next batch of a, b
if(BlockMatrixStrideA != 0)
if(BlockMatrixStrideA != 0 or ib == 0)
{
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
......@@ -265,7 +225,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
a_block_mtx,
p_a_block +
a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
......@@ -273,7 +233,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
if(BlockMatrixStrideB != 0)
if(BlockMatrixStrideB != 0 or ib == 0)
{
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
......@@ -282,25 +242,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
b_block_mtx,
p_b_block +
b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
}
// do last batch of gemm
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC);
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment