Commit 28354a0f authored by Chao Liu's avatar Chao Liu

make LDS double buffer work; 1x1 conv now hits 80% of peak

parent 61ac0866
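
A minimal sketch of the register-clipboard double buffering this commit enables, assuming a toy kernel: every name in it (double_buffer_demo, TILE, the toy reduction) is illustrative only, while the real kernel in the diffs below uses Blockwise2dTensorCopy3::RunLoadRegisterClipboard / RunStoreRegisterClipboard together with the blockwise GEMM.

// Toy illustration of LDS double buffering with a register "clipboard".
// Launch with blockDim.x == TILE; p_src holds ntile tiles of TILE floats.
constexpr unsigned TILE = 64;

__global__ void double_buffer_demo(const float* __restrict__ p_src,
                                   float* __restrict__ p_dst,
                                   unsigned ntile)
{
    __shared__ float lds[2][TILE];

    const unsigned tid = threadIdx.x;

    // preload the first tile straight into LDS buffer 0
    lds[0][tid] = p_src[tid];

    float acc = 0;

    for(unsigned itile = 0; itile + 1 < ntile; ++itile)
    {
        __syncthreads();

        // 1) prefetch the NEXT tile from global memory into registers;
        //    no LDS traffic yet, so this overlaps with the compute below
        const float clipboard = p_src[(itile + 1) * TILE + tid];

        // 2) compute on the CURRENT tile already resident in LDS
        acc += lds[itile % 2][tid] * lds[itile % 2][(tid + 1) % TILE];

        // 3) only now spill the prefetched registers into the other LDS buffer
        lds[(itile + 1) % 2][tid] = clipboard;
    }

    __syncthreads();

    // last computation on the final tile (nothing left to prefetch)
    acc += lds[(ntile - 1) % 2][tid] * lds[(ntile - 1) % 2][(tid + 1) % TILE];

    p_dst[blockIdx.x * TILE + tid] = acc;
}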
......@@ -614,7 +614,7 @@ int main()
nrepeat);
#endif
#if 1
#if 0
if(S == 3 && R == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
......
......@@ -128,7 +128,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr unsigned BlockSize = 64;
#elif 1
// 1x1, 28x28, 128 threads
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
......@@ -215,37 +216,37 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
cudaEventCreate(&start);
cudaEventRecord(start, 0);
#if 1
#if 0
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
#else
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_cnhw_desc),
decltype(wei_csrk_desc),
decltype(out_knhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>
<GridSize,
BlockSize,
T,
decltype(in_cnhw_desc),
decltype(wei_csrk_desc),
decltype(out_knhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>
<<<grid_dim, block_dim>>>(in_cnhw_desc,
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
wei_csrk_desc,
......
......@@ -512,4 +512,196 @@ struct Blockwise2dTensorCopy3
}
}
}
#if 1
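// number of Float elements each thread stages in registers:
// DataPerRead per access, times ceil(L0 / thread_per_d0) accesses along dimension 0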
__device__ constexpr unsigned GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
}
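// read this thread's slice of the source tile from global memory into the register
// clipboard; no LDS traffic here, so the loads can overlap with compute on the current tile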
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* p_clipboard) const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
using Float2 = float2;
using Float4 = float4;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
if(DataPerRead == 1)
{
p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_clipboard + iloop * 2)) =
*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_clipboard + iloop * 4)) =
*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
}
else
{
assert(false);
}
}
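// handle the remainder rows when L0 is not a multiple of thread_per_d0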
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_clipboard[nloop_d0] = p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_clipboard + nloop_d0 * 2)) =
*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_clipboard + nloop_d0 * 4)) =
*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else
{
assert(false);
}
}
}
}
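// write the register clipboard filled by RunLoadRegisterClipboard into the
// destination buffer (LDS), using the same per-thread layout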
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
using Float2 = float2;
using Float4 = float4;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const Float2*>(p_clipboard + iloop * 2));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const Float4*>(p_clipboard + iloop * 4));
}
else
{
assert(false);
}
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = p_clipboard[nloop_d0];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float2*>(p_clipboard + nloop_d0 * 2));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float4*>(p_clipboard + nloop_d0 * 4));
}
else
{
assert(false);
}
}
}
}
#endif
};
......@@ -262,8 +262,26 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
__syncthreads();
// load next data
#if 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
#elif 1
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
#endif
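// at this point the next input and weight tiles are held in registers; they are written
// into the other LDS buffer only after the GEMM below has consumed the current one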
// compute on current data
// a series of GEMM
......@@ -283,6 +301,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
f_accum);
}
}
#if 0
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#elif 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#endif
}
// last computation
......