"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f66efa5d1d8806e6aacd551e3a36da567000bfa1"
Commit adc10088 authored by Chao Liu's avatar Chao Liu
Browse files

tweak

parent 4a1e97cf
...@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock) for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
{ {
blockwise_in_copy.Run(p_in_global, p_in_block); blockwise_in_copy.Run(p_in_global, p_in_block);
...@@ -246,8 +244,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -246,8 +244,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
__syncthreads(); __syncthreads();
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
} }
// copy output: register to global memory // copy output: register to global memory
...@@ -304,8 +302,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -304,8 +302,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{ {
threadwise_out_copy.Run(p_out_thread, p_out_global); threadwise_out_copy.Run(p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true); threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true); True);
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
} }
} }
} }
......
...@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
blockwise_in_copy.Run(p_in_global, p_in_block_double); blockwise_in_copy.Run(p_in_global, p_in_block_double);
...@@ -263,15 +261,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -263,15 +261,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -288,14 +285,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -288,14 +285,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
...@@ -369,8 +366,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -369,8 +366,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{ {
threadwise_out_copy.Run(p_out_thread, p_out_global); threadwise_out_copy.Run(p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true); threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true); True);
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
} }
} }
} }
......
...@@ -447,17 +447,19 @@ struct BlockwiseGenericTensorSliceCopy_v2 ...@@ -447,17 +447,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
} }
template <class T, bool PositiveDirection> template <class T, bool PositiveDirection>
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>) __device__ void
MoveSrcSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
integral_constant<bool, PositiveDirection>{});
} }
template <class T, bool PositiveDirection> template <class T, bool PositiveDirection>
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>) __device__ void
MoveDstSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseLoad.MoveDstSlicingWindow(step_sizes, mThreadwiseLoad.MoveDstSlicingWindow(step_sizes, positive_direction);
integral_constant<bool, PositiveDirection>{});
} }
private: private:
......
...@@ -227,9 +227,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2 ...@@ -227,9 +227,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2
template <class T, bool PositiveDirection> template <class T, bool PositiveDirection>
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>) __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{ {
static_if<PositiveDirection>([&](auto) { mDstSliceOrigin += step_sizes; }).Else([&](auto) { static_if<PositiveDirection>{}([&](auto) {
mDstSliceOrigin -= step_sizes; mDstSliceOrigin += step_sizes;
}); }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
} }
// private: // private:
......
...@@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 #if 1
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
#else #else
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
......
...@@ -379,7 +379,7 @@ int main(int argc, char* argv[]) ...@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment