Commit 8bd6ea1a authored by Chao Liu's avatar Chao Liu
Browse files

improve implicit gemm NCHW, SRCK, NKHW, and tuned

parent 1de6fd07
......@@ -361,7 +361,7 @@ int main()
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 1
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
......@@ -370,15 +370,6 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
// 3x3, 54x54
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 54;
constexpr unsigned WI = 54;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
......@@ -415,6 +406,15 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 7;
constexpr unsigned R = 7;
#elif 1
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
......@@ -449,7 +449,7 @@ int main()
device_direct_convolution_2
#elif 0
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 1
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn
......
......@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58
#elif 0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
......@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5, 36x36
......
......@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 16;
#elif 1
#elif 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
......@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
......@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 2;
// for 3x3, 58x58
constexpr unsigned NPerBlock = 4;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 8;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
......@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
......
......@@ -17,6 +17,7 @@ template <unsigned GridSize,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
......@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
constexpr unsigned NPerThread = NPerBlock;
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
......@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
out_hkwn_thread_desc,
p_out_thread,
out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment