Commit bd811e2c authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent c39c573e
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "device_direct_convolution_1.cuh" #include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh" #include "device_direct_convolution_2.cuh"
//#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" //#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
//#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" #include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh" //#include "device_winograd_convolution.cuh"
...@@ -418,9 +418,9 @@ int main() ...@@ -418,9 +418,9 @@ int main()
device_direct_convolution_2 device_direct_convolution_2
#elif 0 #elif 0
device_implicit_gemm_convolution_1_nchw_kcsr device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw device_implicit_gemm_convolution_2_cnhw_srck_knhw
#elif 0 #elif 0
device_winograd_convolution device_winograd_convolution
......
...@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, ...@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
} }
#endif #endif
// blockwise copy
// wei: format is [S,R,C,K], no conversion needed
constexpr auto blockwise_wei_copy =
blockwise_4d_tensor_copy_1<BlockSize,
Float,
decltype(wei_srck_global_desc),
decltype(wei_srck_block_desc),
decltype(wei_srck_block_desc.GetLengths())>{};
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register // A_matrix and B_matrix saved in LDS, C_matrix saved in register
...@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, ...@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
#if 1 #if 1
// weight: global mem to LDS, // weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed // format is [S,R,C,K], no conversion needed
blockwise_4d_tensor_copy<BlockSize>( blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
wei_srck_global_desc, 0, 0, c_block_data_begin, k_block_data_begin),
p_wei_global + p_wei_block);
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
wei_srck_block_desc,
p_wei_block,
wei_srck_block_desc.GetLengths());
#endif #endif
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment