Commit 8d460740 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 3a6044aa
......@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;
......
......@@ -199,7 +199,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#else
constexpr auto map_k_e_2_e_k = Sequence<1, 0>{};
auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
const auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
BlockSize,
Float,
decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)),
......@@ -296,9 +296,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
}
#endif
#if 0 // debug
return;
#endif
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
......@@ -306,7 +304,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
......@@ -339,14 +338,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if 0
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#endif
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
......@@ -369,14 +369,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if 0
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#endif
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment