Commit 268d1c71 authored by Chao Liu
Browse files

tidy up

parent c9fa46af
...@@ -69,7 +69,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -69,7 +69,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc)); Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
#if 1 #if 0
// 3x3, 34x34 // 3x3, 34x34
// need to use register double buffer for GEMM // need to use register double buffer for GEMM
constexpr index_t BPerBlock = 128; constexpr index_t BPerBlock = 128;
...@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
#elif 1 #elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -217,7 +217,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -217,7 +217,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128; constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
......
...@@ -409,7 +409,7 @@ int main(int argc, char* argv[]) ...@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -580,7 +580,7 @@ int main(int argc, char* argv[]) ...@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 1x1 filter, 14x14 image, C = 2048 // 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 2048; constexpr index_t C = 2048;
...@@ -661,7 +661,7 @@ int main(int argc, char* argv[]) ...@@ -661,7 +661,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0 #elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1 #elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1 #elif 1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn device_implicit_gemm_convolution_2_chwn_cyxk_khwn
......
...@@ -340,7 +340,8 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -340,7 +340,8 @@ struct BlockwiseChwnTensorCopyPadded
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize; constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp = const Float* p_src_tmp =
p_src + src_desc.Get1dIndex(c_block_data_begin, p_src +
src_desc.Get1dIndex(c_block_data_begin,
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low, (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low, (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
n_block_data_begin); n_block_data_begin);
......
...@@ -329,7 +329,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -329,7 +329,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
c_thread_sub_mtx, c_thread_sub_mtx,
p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, p_c_thread +
c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster), n_repeat * NPerLevel1Cluster),
c_block_mtx, c_block_mtx,
p_c_block + p_c_block +
......
...@@ -93,7 +93,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -93,7 +93,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
Float p_out_thread[out_thread_desc.GetElementSpace()]; Float p_out_thread[out_thread_desc.GetElementSpace()];
threadwise_4d_tensor_copy(out_block_desc, threadwise_4d_tensor_copy(out_block_desc,
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin, p_out_block +
out_block_desc.Get1dIndex(n_thread_data_begin,
k_thread_data_begin, k_thread_data_begin,
ho_thread_data_begin, ho_thread_data_begin,
wo_thread_data_begin), wo_thread_data_begin),
...@@ -107,7 +108,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -107,7 +108,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
// threadwise convolution // threadwise convolution
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_thread_block_desc, in_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
in_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data_begin, c_thread_data_begin,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -122,7 +124,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -122,7 +124,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
threadwise_4d_tensor_copy(out_thread_desc, threadwise_4d_tensor_copy(out_thread_desc,
p_out_thread, p_out_thread,
out_block_desc, out_block_desc,
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin, p_out_block +
out_block_desc.Get1dIndex(n_thread_data_begin,
k_thread_data_begin, k_thread_data_begin,
ho_thread_data_begin, ho_thread_data_begin,
wo_thread_data_begin), wo_thread_data_begin),
......
...@@ -183,7 +183,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -183,7 +183,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
const Float* p_in_global_block_begin = const Float* p_in_global_block_begin =
p_in_global + in_chwn_global_desc.Get1dIndex( p_in_global +
in_chwn_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_begin = const Float* p_wei_global_block_begin =
...@@ -267,7 +268,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -267,7 +268,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr index_t N2 = GemmNPerThreadSubC; constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2; constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2; constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K2 = GemmMPerThreadSubC;
......
...@@ -387,11 +387,12 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -387,11 +387,12 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{}); constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
threadwise_6d_tensor_copy(out_6d_thread_desc, threadwise_6d_tensor_copy(
out_6d_thread_desc,
p_out_thread, p_out_thread,
out_6d_global_desc, out_6d_global_desc,
p_out_global + out_kb_global_desc.Get1dIndex( p_out_global +
k_thread_data_begin, b_thread_data_begin), out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
out_6d_thread_desc.GetLengths(), out_6d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{}); Number<OutThreadCopyDataPerWrite>{});
} }
......
...@@ -113,7 +113,8 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ ...@@ -113,7 +113,8 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
c_block_work_begin += CPerBlock) c_block_work_begin += CPerBlock)
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin, blockwise_in_copy.Run(p_in_global +
in_global_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin, c_block_work_begin,
hi_block_work_begin, hi_block_work_begin,
wi_block_work_begin), wi_block_work_begin),
...@@ -143,9 +144,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ ...@@ -143,9 +144,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
} }
// copy output tensor from LDS to device mem // copy output tensor from LDS to device mem
blockwise_out_copy.Run(p_out_block, blockwise_out_copy.Run(
p_out_global + out_global_desc.Get1dIndex(n_block_work_begin, p_out_block,
k_block_work_begin, p_out_global +
ho_block_work_begin, out_global_desc.Get1dIndex(
wo_block_work_begin)); n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
} }
...@@ -175,15 +175,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -175,15 +175,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, blockwise_in_copy.Run(p_in_global +
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin, c_block_data_begin,
hi_block_data_begin, hi_block_data_begin,
wi_block_data_begin), wi_block_data_begin),
p_in_block); p_in_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_global +
wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block); p_wei_block);
__syncthreads(); __syncthreads();
...@@ -194,7 +196,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -194,7 +196,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -206,7 +209,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -206,7 +209,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -224,7 +228,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -224,7 +228,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
......
...@@ -198,8 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -198,8 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_in_vec_block); p_in_vec_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_vec_global +
wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_vec_block); p_wei_vec_block);
__syncthreads(); __syncthreads();
...@@ -210,7 +211,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -210,7 +211,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -222,7 +224,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -222,7 +224,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -240,7 +243,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -240,7 +243,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
......
...@@ -283,7 +283,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -283,7 +283,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, p_out_global +
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
......
...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) ...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
return os; return os;
} }
typedef enum typedef enum {
{
Half = 0, Half = 0,
Float = 1, Float = 1,
} DataType_t; } DataType_t;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment