Commit c138e212 authored by Chao Liu's avatar Chao Liu
Browse files

nchw*cyxk*nkhw on AMD

parent 49d5af10
...@@ -217,9 +217,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, ...@@ -217,9 +217,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 #if 0
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#elif 0 #elif 0
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
#endif #endif
<GridSize, <GridSize,
......
...@@ -57,19 +57,19 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -57,19 +57,19 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 0
// for 3x3, 28x28, v1r2, Pascal // for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8; constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2; constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2; constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 4; constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -81,30 +81,30 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -81,30 +81,30 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>; using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 2; constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4; constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>; using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2; constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0 #elif 0
// for 3x3, 28x28, v1r3, Pascal, bad // for 3x3, 34x34, v1r3, Vega 20
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8; constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2; constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 2; constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 4; constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
...@@ -116,25 +116,25 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -116,25 +116,25 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet constexpr index_t InBlockReorderDataPerWrite_N = 2;
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2; constexpr index_t OutThreadCopyDataPerWrite_W = 4;
#elif 1 #elif 1
// for 3x3, 34x34, v1r3, Pascal // for 3x3, 34x34, v1r3, Vega 20, try
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 2; constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8; constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2; constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 16; constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 2; constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
...@@ -151,15 +151,50 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -151,15 +151,50 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>; using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2; constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#endif #endif
......
...@@ -608,7 +608,7 @@ int main(int argc, char* argv[]) ...@@ -608,7 +608,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1 #elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 1 #elif 1
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
......
...@@ -196,6 +196,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -196,6 +196,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB>{}; GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
#else
return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{}); in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
...@@ -293,7 +304,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -293,7 +304,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
p_wei_register_clipboard); p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
...@@ -322,7 +333,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -322,7 +333,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
p_wei_register_clipboard); p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard( blockwise_in_copy_reorder.RunStoreRegisterClipboard(
...@@ -334,7 +345,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -334,7 +345,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space, p_in_block_double + in_block_space,
p_out_thread); p_out_thread);
} }
......
...@@ -78,22 +78,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -78,22 +78,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup "); "wrong! cannot evenly divide work for workgroup ");
// constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); constexpr auto block_work_desc = make_ConstantTensorDescriptor(
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork); const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock; const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin; const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin; const index_t wi_block_data_begin = wo_block_data_begin;
...@@ -193,6 +191,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -193,6 +191,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB>{}; GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
#else
return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{}); in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
...@@ -222,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -222,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
#if 1 #if 0
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.Get1dIndex(
...@@ -267,7 +276,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -267,7 +276,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
__syncthreads(); __syncthreads();
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
__syncthreads(); __syncthreads();
} }
...@@ -314,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -314,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
__syncthreads(); __syncthreads();
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
__syncthreads(); __syncthreads();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment