"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "cb00a19603dd59ca6682d113d0c9111036eab418"
Commit caf4d7e6 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 5872b710
...@@ -401,10 +401,10 @@ int main() ...@@ -401,10 +401,10 @@ int main()
#elif 0 #elif 0
device_direct_convolution_2( device_direct_convolution_2(
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
#elif 0 #elif 1
device_implicit_gemm_convolution_nchw_kcsr( device_implicit_gemm_convolution_nchw_kcsr(
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
#elif 1 #elif 0
device_implicit_gemm_convolution_nchw_srck( device_implicit_gemm_convolution_nchw_srck(
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
#elif 0 #elif 0
......
...@@ -58,22 +58,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -58,22 +58,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
constexpr unsigned WBlockWork = constexpr unsigned WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
unsigned itmp = get_block_1d_id();
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin;
const unsigned wi_block_data_begin = wo_block_data_begin;
// tensor view of un-reorderd blockwise input and weight (imaginary) // tensor view of un-reorderd blockwise input and weight (imaginary)
constexpr auto in_nchw_block_desc = constexpr auto in_nchw_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}); make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
...@@ -106,6 +90,23 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -106,6 +90,23 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
} }
#endif #endif
// my block work
unsigned itmp = get_block_1d_id();
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin;
const unsigned wi_block_data_begin = wo_block_data_begin;
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register // A_matrix and B_matrix saved in LDS, C_matrix saved in register
...@@ -156,6 +157,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -156,6 +157,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
#if 1
// input: global mem to LDS, // input: global mem to LDS,
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N] // convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
...@@ -168,7 +170,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -168,7 +170,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
p_in_block, p_in_block,
in_nchw_block_desc.GetLengths(), in_nchw_block_desc.GetLengths(),
reorder_chwn_from_nchw); reorder_chwn_from_nchw);
#else
// input: global mem to LDS,
// no format conversion, this is wrong, for performance study only!
blockwise_4d_tensor_copy<BlockSize>(in_nchw_global_desc,
p_in_global +
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin,
hi_block_data_begin,
wi_block_data_begin),
in_nchw_block_desc,
p_in_block,
in_nchw_block_desc.GetLengths());
#endif
#if 1
// weight: global mem to LDS, // weight: global mem to LDS,
// convert [K,C,S,R] to [S,R,C,K] // convert [K,C,S,R] to [S,R,C,K]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
...@@ -179,6 +195,17 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -179,6 +195,17 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
p_wei_block, p_wei_block,
wei_kcsr_block_desc.GetLengths(), wei_kcsr_block_desc.GetLengths(),
reorder_srck_from_kcsr); reorder_srck_from_kcsr);
#else
// weight: global mem to LDS,
// no format conversion, this is wrong, for performance study only!
blockwise_4d_tensor_copy<BlockSize>(
wei_kcsr_global_desc,
p_wei_global +
wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
wei_kcsr_block_desc,
p_wei_block,
wei_kcsr_block_desc.GetLengths());
#endif
__syncthreads(); __syncthreads();
...@@ -204,6 +231,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -204,6 +231,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
const unsigned k_thread_data_begin = matrix_c_index.row_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
#if 1
// output: register to global mem, // output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
...@@ -218,4 +246,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, ...@@ -218,4 +246,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_nkhw_from_hkwn); reorder_nkhw_from_hkwn);
#else
// output: register to global mem,
// no format conversion, assume register is in [N,K,Ho,Wo], this is wrong, for performance
// study only!
constexpr auto out_nkhw_thread_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, KPerThread, HoPerThread, WoPerThread>{});
threadwise_4d_tensor_copy(
out_nkhw_thread_desc,
p_out_thread,
out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths());
#endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment