"...composable_kernel-1.git" did not exist on "984b3722bfe45dcfecf040535c7e6a5d2c962c26"
Commit 8d460740 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 3a6044aa
...@@ -443,7 +443,7 @@ int main(int argc, char* argv[]) ...@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
......
...@@ -199,7 +199,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -199,7 +199,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#else #else
constexpr auto map_k_e_2_e_k = Sequence<1, 0>{}; constexpr auto map_k_e_2_e_k = Sequence<1, 0>{};
auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3< const auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
BlockSize, BlockSize,
Float, Float,
decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)), decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)),
...@@ -296,9 +296,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -296,9 +296,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
} }
#endif #endif
#if 0 // debug const Float* p_wei_block_on_global = p_wei_global;
return;
#endif
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -306,7 +304,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -306,7 +304,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard); blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double); blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
...@@ -339,14 +338,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -339,14 +338,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if 0 #if 0
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#else #else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#endif #endif
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard); blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread); run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -369,14 +369,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -369,14 +369,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if 0 #if 0
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#else #else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#endif #endif
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard); blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard); blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread); run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment