Commit 216e3da6 authored by Chao Liu
Browse files

bug fix and tune implicit gemm

parent caf4d7e6
...@@ -389,7 +389,7 @@ int main() ...@@ -389,7 +389,7 @@ int main()
#if 0 #if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif #endif
...@@ -413,7 +413,7 @@ int main() ...@@ -413,7 +413,7 @@ int main()
#endif #endif
} }
#if 1 #if 0
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host); host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#elif 0 #elif 0
......
...@@ -38,7 +38,7 @@ void device_implicit_gemm_convolution_nchw_kcsr( ...@@ -38,7 +38,7 @@ void device_implicit_gemm_convolution_nchw_kcsr(
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 16; constexpr unsigned BlockSize = 16;
#elif 1 #elif 0
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32; constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 4;
...@@ -51,19 +51,19 @@ void device_implicit_gemm_convolution_nchw_kcsr( ...@@ -51,19 +51,19 @@ void device_implicit_gemm_convolution_nchw_kcsr(
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 1
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2; constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 32;
constexpr unsigned KPerThread = 4; constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2; constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 4;
constexpr unsigned BlockSize = 256; constexpr unsigned BlockSize = 128;
#endif #endif
constexpr unsigned GridSize = constexpr unsigned GridSize =
......
...@@ -105,13 +105,24 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c ...@@ -105,13 +105,24 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id()); const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
mMyThreadOffsetA = c_thread_mtx_index.batch_begin * a_block_mtx.GetElementSpace() + mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0) ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin)); : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
mMyThreadOffsetB = c_thread_mtx_index.batch_begin * b_block_mtx.GetElementSpace() + mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin) ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0)); : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
#if 0
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch_begin,
c_thread_mtx_index.row_begin,
c_thread_mtx_index.col_begin,
mMyThreadOffsetA,
mMyThreadOffsetB);
#endif
} }
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
......
...@@ -174,7 +174,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, ...@@ -174,7 +174,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{ {
// threadwise convolution // threadwise convolution
#if 1 #if 0
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_thread_block_desc, in_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
......
...@@ -84,6 +84,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -84,6 +84,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
...@@ -184,7 +185,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -184,7 +185,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
in_nchw_block_desc.GetLengths()); in_nchw_block_desc.GetLengths());
#endif #endif
#if 1 #if 0
// weight: global mem to LDS, // weight: global mem to LDS,
// convert [K,C,S,R] to [S,R,C,K] // convert [K,C,S,R] to [S,R,C,K]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
...@@ -209,6 +210,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -209,6 +210,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
__syncthreads(); __syncthreads();
#if 1
// a series of batched GEMM // a series of batched GEMM
for(unsigned s = 0; s < S; ++s) for(unsigned s = 0; s < S; ++s)
{ {
...@@ -222,16 +224,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -222,16 +224,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
f_accum); f_accum);
} }
} }
#endif
} }
const auto matrix_c_index = const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
#if 0
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
#endif
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
#if 1 #if 0
// output: register to global mem, // output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
......
...@@ -151,6 +151,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -151,6 +151,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_w_new_read>{}); in_w_new_read>{});
#if 0 #if 0
// this version reuses old input data in registers, and reads new data from LDS
// loop over vertical direction // loop over vertical direction
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
{ {
...@@ -200,6 +201,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -200,6 +201,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
} }
} }
#elif 1 #elif 1
// this version reads all input from LDS when the filter moves
// loop over vertical direction // loop over vertical direction
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
{ {
...@@ -226,4 +228,4 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -226,4 +228,4 @@ __device__ void threadwise_direct_convolution_3(InDesc,
} }
} }
#endif #endif
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment