Commit 42f4c7fd authored by Chao Liu
Browse files

refactor

parent 6614729a
...@@ -77,8 +77,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, ...@@ -77,8 +77,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned GemmRowThreadPerCluster = 4; constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmColumnThreadPerCluster = 8; constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4; constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16; constexpr unsigned InBlockCopyThreadPerDim1 = 16;
...@@ -120,7 +120,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, ...@@ -120,7 +120,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
#if 1 #if 1
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
#else #elif 0
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
#endif #endif
<GridSize, <GridSize,
...@@ -135,8 +135,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, ...@@ -135,8 +135,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
BPerThread, BPerThread,
KPerThread, KPerThread,
CPerThread, CPerThread,
GemmRowThreadPerCluster, GemmThreadPerColumnPerCluster,
GemmColumnThreadPerCluster, GemmThreadPerRowPerCluster,
InBlockCopyThreadPerDim0, InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1, InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim0,
......
...@@ -76,8 +76,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -76,8 +76,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 1; constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned GemmThreadPerClusterRow = 1; constexpr unsigned GemmThreadPerColumnPerCluster = 1;
constexpr unsigned GemmThreadPerClusterColumn = 4; constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned BlockSize = 32; constexpr unsigned BlockSize = 32;
#elif 0 #elif 0
...@@ -89,8 +89,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -89,8 +89,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 8; constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned GemmThreadPerClusterRow = 4; constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerClusterColumn = 4; constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 0
...@@ -102,8 +102,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -102,8 +102,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 8; constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned GemmRowThreadPerCluster = 4; constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmColumnThreadPerCluster = 4; constexpr unsigned GemmThreadPerRowPerCluster = 4;
constexpr unsigned InBlockCopyThreadPerDim0 = 2; constexpr unsigned InBlockCopyThreadPerDim0 = 2;
constexpr unsigned InBlockCopyThreadPerDim1 = 64; constexpr unsigned InBlockCopyThreadPerDim1 = 64;
...@@ -119,8 +119,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -119,8 +119,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2; constexpr unsigned CPerThread = 2;
constexpr unsigned GemmRowThreadPerCluster = 8; constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmColumnThreadPerCluster = 8; constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 8; constexpr unsigned InBlockCopyThreadPerDim0 = 8;
constexpr unsigned InBlockCopyThreadPerDim1 = 16; constexpr unsigned InBlockCopyThreadPerDim1 = 16;
...@@ -171,8 +171,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -171,8 +171,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
BPerThread, BPerThread,
KPerThread, KPerThread,
CPerThread, CPerThread,
GemmRowThreadPerCluster, GemmThreadPerColumnPerCluster,
GemmColumnThreadPerCluster, GemmThreadPerRowPerCluster,
InBlockCopyThreadPerDim0, InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1> InBlockCopyThreadPerDim1>
<<<grid_dim, block_dim>>>(in_cnhw_desc, <<<grid_dim, block_dim>>>(in_cnhw_desc,
......
...@@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3 ...@@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3
assert(false); assert(false);
} }
} }
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else
{
assert(false);
}
}
}
} }
}; };
...@@ -20,8 +20,8 @@ template <unsigned GridSize, ...@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned GemmThreadPerClusterRow, unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerClusterColumn, unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0, unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1, unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0, unsigned WeiBlockCopyThreadPerDim0,
...@@ -192,8 +192,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, ...@@ -192,8 +192,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
false, false,
false, false,
CPerThread, CPerThread,
GemmThreadPerClusterRow, GemmThreadPerColumnPerCluster,
GemmThreadPerClusterColumn, GemmThreadPerRowPerCluster,
true>{}; true>{};
// LDS // LDS
......
...@@ -20,8 +20,8 @@ template <unsigned GridSize, ...@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned GemmThreadPerClusterRow, unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerClusterColumn, unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0, unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1, unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0, unsigned WeiBlockCopyThreadPerDim0,
...@@ -192,8 +192,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline ...@@ -192,8 +192,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
false, false,
false, false,
CPerThread, CPerThread,
GemmThreadPerClusterRow, GemmThreadPerColumnPerCluster,
GemmThreadPerClusterColumn, GemmThreadPerRowPerCluster,
true>{}; true>{};
// LDS // LDS
......
...@@ -20,8 +20,8 @@ template <unsigned GridSize, ...@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned GemmThreadPerClusterRow, unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerClusterColumn, unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0, unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1> unsigned InBlockCopyThreadPerDim1>
__global__ void __global__ void
...@@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
false, false,
false, false,
CPerThread, CPerThread,
GemmThreadPerClusterRow, GemmThreadPerColumnPerCluster,
GemmThreadPerClusterColumn, GemmThreadPerRowPerCluster,
true>{}; true>{};
// LDS // LDS
......
...@@ -20,8 +20,8 @@ template <unsigned GridSize, ...@@ -20,8 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned GemmRowThreadPerCluster, unsigned GemmThreadPerColumnPerCluster,
unsigned GemmColumnThreadPerCluster, unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0, unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1> unsigned InBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline( __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline(
...@@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline ...@@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
false, false,
false, false,
CPerThread, CPerThread,
GemmRowThreadPerCluster, GemmThreadPerColumnPerCluster,
GemmColumnThreadPerCluster, GemmThreadPerRowPerCluster,
true>{}; true>{};
// LDS // LDS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment