Commit 42f4c7fd authored by Chao Liu
Browse files

refactor

parent 6614729a
......@@ -77,8 +77,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned GemmRowThreadPerCluster = 4;
constexpr unsigned GemmColumnThreadPerCluster = 8;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
......@@ -120,7 +120,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
#if 1
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
#else
#elif 0
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
#endif
<GridSize,
......@@ -135,8 +135,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
BPerThread,
KPerThread,
CPerThread,
GemmRowThreadPerCluster,
GemmColumnThreadPerCluster,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
......
......@@ -76,8 +76,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned GemmThreadPerClusterRow = 1;
constexpr unsigned GemmThreadPerClusterColumn = 4;
constexpr unsigned GemmThreadPerColumnPerCluster = 1;
constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned BlockSize = 32;
#elif 0
......@@ -89,8 +89,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1;
constexpr unsigned GemmThreadPerClusterRow = 4;
constexpr unsigned GemmThreadPerClusterColumn = 4;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned BlockSize = 128;
#elif 0
......@@ -102,8 +102,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1;
constexpr unsigned GemmRowThreadPerCluster = 4;
constexpr unsigned GemmColumnThreadPerCluster = 4;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned InBlockCopyThreadPerDim0 = 2;
constexpr unsigned InBlockCopyThreadPerDim1 = 64;
......@@ -119,8 +119,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned GemmRowThreadPerCluster = 8;
constexpr unsigned GemmColumnThreadPerCluster = 8;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 8;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
......@@ -171,8 +171,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
BPerThread,
KPerThread,
CPerThread,
GemmRowThreadPerCluster,
GemmColumnThreadPerCluster,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>
<<<grid_dim, block_dim>>>(in_cnhw_desc,
......
......@@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3
assert(false);
}
}
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else
{
assert(false);
}
}
}
}
};
......@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmThreadPerClusterRow,
unsigned GemmThreadPerClusterColumn,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
......@@ -192,8 +192,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
false,
false,
CPerThread,
GemmThreadPerClusterRow,
GemmThreadPerClusterColumn,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
// LDS
......
......@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmThreadPerClusterRow,
unsigned GemmThreadPerClusterColumn,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
......@@ -192,8 +192,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
false,
false,
CPerThread,
GemmThreadPerClusterRow,
GemmThreadPerClusterColumn,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
// LDS
......
......@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmThreadPerClusterRow,
unsigned GemmThreadPerClusterColumn,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1>
__global__ void
......@@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
false,
false,
CPerThread,
GemmThreadPerClusterRow,
GemmThreadPerClusterColumn,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
// LDS
......
......@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmRowThreadPerCluster,
unsigned GemmColumnThreadPerCluster,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline(
......@@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
false,
false,
CPerThread,
GemmRowThreadPerCluster,
GemmColumnThreadPerCluster,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
// LDS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment