"src/include/amd_inline_asm.hpp" did not exist on "b188c0d2437448cf7a17f904c9ea721b8a2fdd89"
Commit c8667736 authored by Chao Liu's avatar Chao Liu
Browse files

unroll some loop, register double buffer gemm

parent 1b323316
...@@ -611,7 +611,7 @@ int main() ...@@ -611,7 +611,7 @@ int main()
nrepeat); nrepeat);
#endif #endif
#if 0 #if 1
if(S == 3 && R == 3) if(S == 3 && R == 3)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
......
...@@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc, ...@@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc)); Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
#if 0 #if 1
// 1x1, 28x28 // 1x1, 28x28
constexpr unsigned BPerBlock = 64; constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8; constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned GemmMPerThreadSubC = 16;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 8;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 2;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
// 1x1, 28x28 try
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8; constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8; constexpr unsigned KPerThread = 8;
......
...@@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
#pragma unroll
// copy A-sub to form A // copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
...@@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths());
} }
#pragma unroll
// copy B-sub to form B // copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
...@@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
f_accum); f_accum);
} }
} }
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
FloatB* const p_b_block,
FloatC* p_c_thread,
Accumulator f_accum) const
{
constexpr auto True = Constant<bool, true>{};
constexpr auto False = Constant<bool, false>{};
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
const auto a_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
// thread A-sub, B-sub for copy
const auto a_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<MPerThreadSubC>{},
Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<NPerThreadSubC>{},
Number<NPerThread>{}); // constexpr doesn't compile
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
}
bool even_loop = true;
#pragma unroll
for(unsigned k_begin = 0; k_begin + 1 < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
(k_begin + 1) * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
(k_begin + 1) * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread,
f_accum);
}
// last loop
{
even_loop = !even_loop;
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread,
f_accum);
}
}
}; };
...@@ -237,7 +237,12 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, ...@@ -237,7 +237,12 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), #if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block + s * Wi + r, p_in_block + s * Wi + r,
p_out_thread, p_out_thread,
f_accum); f_accum);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment