Commit 50b96745 authored by Chao Liu
Browse files

gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn now uses the khwn layout for the thread C data

parent 1cb98850
......@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
......
......@@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
constexpr auto out_khwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
......@@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{});
#if 0
const auto blockwise_batch_gemm =
......@@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
false,
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
HoPerThread,
GemmKPerThreadLoop,
......@@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
decltype(c_kxwn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
......@@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
const Float* p_in_global_block_begin =
p_in_global + in_chwn_global_desc.Get1dIndex(
......@@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
#if 0
// for v1 batch-gemm
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
threadwise_4d_tensor_copy(
out_khwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
out_khwn_thread_desc.GetLengths());
#else
for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho)
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k)
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo)
for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n)
for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n);
const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
......@@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread - NPerBlock * wo_thread;
const unsigned n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)];
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment