"launcher/vscode:/vscode.git/clone" did not exist on "5e6ddfd6a4fecc394255d7109f87c420c98b4e15"
Commit 50b96745 authored by Chao Liu
Browse files

gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn use khwn for thread C data now

parent 1cb98850
...@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 1 #elif 0
// for 1x1, 28x28 // for 1x1, 28x28
constexpr unsigned NPerBlock = 16; constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128; constexpr unsigned KPerBlock = 128;
......
...@@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{}); Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc = constexpr auto out_khwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{}); make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
#if 0 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...@@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Number<in_chwn_block_desc.GetStride(I0)>{}); Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc = constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{});
#if 0 #if 0
const auto blockwise_batch_gemm = const auto blockwise_batch_gemm =
...@@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
false, false,
0, 0,
in_chwn_block_desc.GetStride(I1), in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0), out_khwn_thread_desc.GetStride(I1),
HoPerBlock, HoPerBlock,
HoPerThread, HoPerThread,
GemmKPerThreadLoop, GemmKPerThreadLoop,
...@@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
decltype(c_kxwn_thread_mtx_desc), decltype(c_kxwn_thread_mtx_desc),
0, 0,
in_chwn_block_desc.GetStride(I1), in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0), out_khwn_thread_desc.GetStride(I1),
HoPerBlock, HoPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
...@@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// register // register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
const Float* p_in_global_block_begin = const Float* p_in_global_block_begin =
p_in_global + in_chwn_global_desc.Get1dIndex( p_in_global + in_chwn_global_desc.Get1dIndex(
...@@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N] // convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
#if 0 #if 0
// for v1 batch-gemm // for v1 batch-gemm
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned k_thread_data_begin = c_thread_mtx_begin.row; const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock; const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{}; threadwise_4d_tensor_copy(
out_khwn_thread_desc,
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_khwn_thread_desc.GetLengths());
reorder_khwn_from_hkwn);
#else #else
for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho) for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{ {
for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k) for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{ {
for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo) for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{ {
for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n) for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{ {
const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n); const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance = const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
...@@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const unsigned wo_thread = b_thread / NPerBlock; const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread - NPerBlock * wo_thread; const unsigned n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread, ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread, wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] = n_block_data_begin + n_thread)] =
p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)]; p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
} }
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment