"vscode:/vscode.git/clone" did not exist on "5ee304595c358203d218d05bcd9cfaf6308f89b7"
Commit ff7a6219 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 89ee2597
...@@ -336,14 +336,6 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -336,14 +336,6 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main() int main()
{ {
#if 0 #if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 4;
constexpr unsigned WI = 4;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
constexpr unsigned N = 1; constexpr unsigned N = 1;
constexpr unsigned C = 1; constexpr unsigned C = 1;
constexpr unsigned HI = 34; constexpr unsigned HI = 34;
...@@ -369,12 +361,12 @@ int main() ...@@ -369,12 +361,12 @@ int main()
constexpr unsigned R = 3; constexpr unsigned R = 3;
#elif 0 #elif 0
constexpr unsigned N = 64; constexpr unsigned N = 64;
constexpr unsigned C = 64; constexpr unsigned C = 256;
constexpr unsigned HI = 66; constexpr unsigned HI = 36;
constexpr unsigned WI = 66; constexpr unsigned WI = 36;
constexpr unsigned K = 64; constexpr unsigned K = 64;
constexpr unsigned S = 3; constexpr unsigned S = 5;
constexpr unsigned R = 3; constexpr unsigned R = 5;
#endif #endif
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
......
...@@ -52,7 +52,7 @@ void device_implicit_gemm_convolution( ...@@ -52,7 +52,7 @@ void device_implicit_gemm_convolution(
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 1 #elif 0
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 4;
...@@ -60,7 +60,7 @@ void device_implicit_gemm_convolution( ...@@ -60,7 +60,7 @@ void device_implicit_gemm_convolution(
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 32;
constexpr unsigned KPerThread = 4; constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
......
...@@ -152,7 +152,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, ...@@ -152,7 +152,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
#if 1
// input: global mem to LDS, // input: global mem to LDS,
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
...@@ -165,9 +164,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, ...@@ -165,9 +164,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
p_in_block, p_in_block,
in_nchw_block_desc.GetLengths(), in_nchw_block_desc.GetLengths(),
reorder_chwn_from_nchw); reorder_chwn_from_nchw);
#endif
#if 1
// weight: global mem to LDS, // weight: global mem to LDS,
blockwise_4d_tensor_copy<BlockSize>( blockwise_4d_tensor_copy<BlockSize>(
wei_srck_global_desc, wei_srck_global_desc,
...@@ -176,11 +173,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, ...@@ -176,11 +173,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
wei_srck_block_desc, wei_srck_block_desc,
p_wei_block, p_wei_block,
wei_srck_block_desc.GetLengths()); wei_srck_block_desc.GetLengths());
#endif
__syncthreads(); __syncthreads();
#if 1
// a series of batched GEMM // a series of batched GEMM
for(unsigned s = 0; s < S; ++s) for(unsigned s = 0; s < S; ++s)
{ {
...@@ -194,7 +189,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, ...@@ -194,7 +189,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
f_accum); f_accum);
} }
} }
#endif
} }
const auto matrix_c_index = const auto matrix_c_index =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment