Commit 2603bb0f authored by Chao Liu's avatar Chao Liu
Browse files

tuning on vega 20

parent a9031464
...@@ -175,15 +175,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -175,15 +175,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, blockwise_in_copy.Run(p_in_global +
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin, c_block_data_begin,
hi_block_data_begin, hi_block_data_begin,
wi_block_data_begin), wi_block_data_begin),
p_in_block); p_in_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_global +
wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block); p_wei_block);
__syncthreads(); __syncthreads();
...@@ -194,7 +196,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -194,7 +196,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -206,7 +209,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -206,7 +209,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -224,7 +228,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -224,7 +228,8 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
......
...@@ -198,8 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -198,8 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_in_vec_block); p_in_vec_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_vec_global +
wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_vec_block); p_wei_vec_block);
__syncthreads(); __syncthreads();
...@@ -210,7 +211,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -210,7 +211,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -222,7 +224,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -222,7 +224,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
...@@ -240,7 +243,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -240,7 +243,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
......
...@@ -283,7 +283,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -283,7 +283,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, p_out_global +
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
......
...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) ...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
return os; return os;
} }
typedef enum typedef enum {
{
Half = 0, Half = 0,
Float = 1, Float = 1,
} DataType_t; } DataType_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment