Commit 8bd6ea1a authored by Chao Liu's avatar Chao Liu
Browse files

improve implicit gemm NCHW, SRCK, NKHW, and tuned

parent 1de6fd07
...@@ -361,7 +361,7 @@ int main() ...@@ -361,7 +361,7 @@ int main()
constexpr unsigned K = 1; constexpr unsigned K = 1;
constexpr unsigned S = 3; constexpr unsigned S = 3;
constexpr unsigned R = 3; constexpr unsigned R = 3;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr unsigned N = 64; constexpr unsigned N = 64;
constexpr unsigned C = 256; constexpr unsigned C = 256;
...@@ -370,15 +370,6 @@ int main() ...@@ -370,15 +370,6 @@ int main()
constexpr unsigned K = 64; constexpr unsigned K = 64;
constexpr unsigned S = 3; constexpr unsigned S = 3;
constexpr unsigned R = 3; constexpr unsigned R = 3;
#elif 0
// 3x3, 54x54
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 54;
constexpr unsigned WI = 54;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0 #elif 0
// 3x3, 56x56 // 3x3, 56x56
constexpr unsigned N = 64; constexpr unsigned N = 64;
...@@ -415,6 +406,15 @@ int main() ...@@ -415,6 +406,15 @@ int main()
constexpr unsigned K = 64; constexpr unsigned K = 64;
constexpr unsigned S = 7; constexpr unsigned S = 7;
constexpr unsigned R = 7; constexpr unsigned R = 7;
#elif 1
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#endif #endif
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
...@@ -449,7 +449,7 @@ int main() ...@@ -449,7 +449,7 @@ int main()
device_direct_convolution_2 device_direct_convolution_2
#elif 0 #elif 0
device_implicit_gemm_convolution_1_nchw_kcsr device_implicit_gemm_convolution_1_nchw_kcsr
#elif 1 #elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn device_implicit_gemm_convolution_1_chwn_csrk_khwn
......
...@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8; constexpr unsigned BlockSize = 8;
#elif 1 #elif 0
// for 3x3, 34x34 | 3x3 58x58 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16; constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 4;
...@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 0
// for 5x5, 36x36 // for 5x5, 36x36
......
...@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 16; constexpr unsigned BlockSize = 16;
#elif 1 #elif 0
// for 3x3, 34x34 // for 3x3, 34x34
constexpr unsigned NPerBlock = 1; constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
...@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned HoPerBlock = 4; constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 2;
...@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 0
// for 3x3, 34x34 // for 3x3, 58x58
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 4;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2; constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 8;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2; constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
...@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
CPerBlock, CPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
NPerThread,
KPerThread, KPerThread,
CPerThread, CPerThread,
HoPerThread, HoPerThread,
......
...@@ -17,6 +17,7 @@ template <unsigned GridSize, ...@@ -17,6 +17,7 @@ template <unsigned GridSize,
unsigned CPerBlock, unsigned CPerBlock,
unsigned HoPerBlock, unsigned HoPerBlock,
unsigned WoPerBlock, unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned HoPerThread, unsigned HoPerThread,
...@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, ...@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
constexpr unsigned NPerThread = NPerBlock; static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, ...@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
// output: register to global mem, // output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
...@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, ...@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment