"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "219deef165d905dcc78100b9bd4485c988a01820"
Commit 18a81e35 authored by Chao Liu's avatar Chao Liu
Browse files

adding assembly

parent 8c923db4
@@ -435,11 +435,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
        #pragma unroll
        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // read first batch of A, B
            // copy A-sub to form A
            //#pragma unroll
            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
#if 0
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
@@ -447,12 +448,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                    a_thread_mtx,
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                    a_thread_sub_mtx.GetLengths());
#else
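                // manual element-by-element copy of this m_repeat's A sub-tile from
                // the block buffer into thread-private registers; bounds follow
                // a_thread_sub_mtx, matching the GetLengths() copy disabled above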
                for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                {
                    for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
m_repeat * MPerLevel1Cluster + j) +
mMyThreadOffsetA];
}
}
#endif
            }
            // copy B-sub to form B
            //#pragma unroll
            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
#if 0
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
@@ -460,13 +474,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
#else
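                // manual element-by-element copy of this n_repeat's B sub-tile from
                // the block buffer into thread-private registers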
                for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                {
                    for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
                                                     n_repeat * NPerLevel1Cluster + j) +
mMyThreadOffsetB];
}
}
#endif
            }
            // loop over batch
            //#pragma unroll
            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
            {
                // do current batch of gemm
#if 0
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
@@ -477,13 +504,32 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC,
                                f_accum);
#else
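                // manual threadwise gemm: C[i][j] += A[k][i] * B[k][j];
                // A is indexed (k, i) because it is stored transposed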
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
{
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
}
#endif
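                // after computing the current batch, load batch ib + 1 into the same
                // thread registers; the final batch's gemm is peeled off below the loop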
                // read next batch of a, b
                if(BlockMatrixStrideA != 0)
                {
                    //#pragma unroll
                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
#if 0
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
@@ -492,14 +538,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                            a_thread_mtx,
                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths());
#else
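                        // same manual A sub-tile copy as above, offset by
                        // (ib + 1) * BlockMatrixStrideA to address the next batch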
                        for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                        {
                            for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(
k_begin + i, m_repeat * MPerLevel1Cluster + j) +
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
}
}
#endif
                    }
                }
                if(BlockMatrixStrideB != 0)
                {
                    //#pragma unroll
                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
#if 0
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
@@ -508,11 +568,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                            b_thread_mtx,
                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths());
#else
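                        // same manual B sub-tile copy as above, offset by
                        // (ib + 1) * BlockMatrixStrideB to address the next batch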
                        for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                        {
                            for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(
                                           k_begin + i, n_repeat * NPerLevel1Cluster + j) +
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
}
}
#endif
                    }
                }
            }
            // do last batch of gemm
#if 0
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
@@ -523,6 +597,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                            False,
                            p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
                            f_accum);
#else
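            // peeled final iteration: same manual k/i/j fma loop, writing the last
            // batch at offset (BatchPerThread - 1) * ThreadMatrixStrideC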
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
{
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + (BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
}
#endif
        }
    }
...
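For reference, a minimal standalone sketch of the pattern this commit switches to: the generic threadwise_matrix_copy / threadwise_gemm helpers are kept under #if 0 and replaced by explicit element loops, presumably so the inner loops map more directly to hand-written assembly (the commit message is "adding assembly"). The sketch assumes plain row-major, K-major-A layouts; the sizes, names, and main() harness are illustrative only, not the repo's actual types.

// sketch_threadwise_gemm.cpp -- illustrative only, not the repo's actual code
#include <cstdio>

int main()
{
    // assumed thread-tile sizes; the real kernel derives these from its
    // template parameters (KPerThreadLoop, MPerThread, NPerThread, ...)
    constexpr unsigned K = 2, M = 4, N = 4;

    // A is stored K x M (i.e. transposed), B is K x N, both row-major
    float a_thread[K * M], b_thread[K * N];
    for(unsigned i = 0; i < K * M; ++i)
        a_thread[i] = 1.0f + i;
    for(unsigned i = 0; i < K * N; ++i)
        b_thread[i] = 0.5f;

    // the same k/i/j loop nest as the #else branches in the diff:
    // C[i][j] += A[k][i] * B[k][j], with A indexed (k, i) since it is transposed
    float c_thread[M * N] = {};
    for(unsigned k = 0; k < K; ++k)
        for(unsigned i = 0; i < M; ++i)
            for(unsigned j = 0; j < N; ++j)
                c_thread[i * N + j] += a_thread[k * M + i] * b_thread[k * N + j];

    printf("c[0][0] = %.1f\n", c_thread[0]); // (1 + 5) * 0.5 = 3.0
    return 0;
}

Written this way, the compiler sees a straight-line nest of multiply-accumulates instead of calls into generic copy/gemm templates, which makes it easier to inspect the generated ISA or to substitute inline assembly for the loop body.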