"python/vscode:/vscode.git/clone" did not exist on "1f26e8b8e4c8b884e59036dccd87929b2af592f9"
Commit 2603bb0f authored by Chao Liu's avatar Chao Liu
Browse files

tuning on vega 20

parent a9031464
...@@ -175,16 +175,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -175,16 +175,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, blockwise_in_copy.Run(p_in_global +
c_block_data_begin, in_nchw_global_desc.Get1dIndex(n_block_data_begin,
hi_block_data_begin, c_block_data_begin,
wi_block_data_begin), hi_block_data_begin,
wi_block_data_begin),
p_in_block); p_in_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_global +
p_wei_block); wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block);
__syncthreads(); __syncthreads();
...@@ -194,10 +196,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -194,10 +196,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
c_thread_data, in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
hi_thread_data_begin, c_thread_data,
wi_thread_data_begin), hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_thread_block_desc, wei_kcyx_thread_block_desc,
p_wei_block + p_wei_block +
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
...@@ -206,10 +209,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -206,10 +209,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, p_in_block +
c_thread_data, in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
hi_thread_data_begin, c_thread_data,
wi_thread_data_begin), hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_thread_block_desc, wei_kcyx_thread_block_desc,
p_wei_block + p_wei_block +
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
...@@ -224,9 +228,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -224,9 +228,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
k_block_data_begin + k_thread_data_begin, out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths()); out_nkhw_thread_desc.GetLengths());
} }
...@@ -198,9 +198,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -198,9 +198,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_in_vec_block); p_in_vec_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex( blockwise_wei_copy.Run(
k_block_data_begin, c_block_data_begin, 0, 0), p_wei_vec_global +
p_wei_vec_block); wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_vec_block);
__syncthreads(); __syncthreads();
...@@ -210,10 +211,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -210,10 +211,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
c_thread_data, in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
hi_thread_data_begin, c_thread_data,
wi_thread_data_begin), hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc,
p_wei_vec_block + p_wei_vec_block +
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
...@@ -222,10 +224,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -222,10 +224,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, p_in_vec_block +
c_thread_data, in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
hi_thread_data_begin, c_thread_data,
wi_thread_data_begin), hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc,
p_wei_vec_block + p_wei_vec_block +
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
...@@ -240,9 +243,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -240,9 +243,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global +
k_block_data_begin + k_thread_data_begin, out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths()); out_nkhw_thread_desc.GetLengths());
} }
...@@ -283,10 +283,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -283,10 +283,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, p_out_global +
ho_block_data_begin + ho_thread_data_begin, out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
n_block_data_begin + n_thread_data_begin), wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn); reorder_khwn_from_hkwn);
} }
...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) ...@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
return os; return os;
} }
typedef enum typedef enum {
{
Half = 0, Half = 0,
Float = 1, Float = 1,
} DataType_t; } DataType_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment