Commit b30edb4c authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent d21e9beb
...@@ -196,7 +196,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -196,7 +196,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel( float time = launch_kernel(
#if 1 #if 0
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
#else #else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
......
...@@ -595,7 +595,7 @@ int main(int argc, char* argv[]) ...@@ -595,7 +595,7 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
// 1x1 filter, 14x14 image, C = 512 // 1x1 filter, 14x14 image, C = 512
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 2048;
constexpr index_t HI = 14; constexpr index_t HI = 14;
constexpr index_t WI = 14; constexpr index_t WI = 14;
constexpr index_t K = 512; constexpr index_t K = 512;
......
...@@ -408,7 +408,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -408,7 +408,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_lds_begin); p_lds_begin);
} }
#if 1 #if 0
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
" ::); " ::);
......
...@@ -213,9 +213,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -213,9 +213,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1 #if 1
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 1
blockwise_gemm.Run_asm
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
......
...@@ -19,8 +19,6 @@ template <index_t GridSize, ...@@ -19,8 +19,6 @@ template <index_t GridSize,
index_t CPerBlock, index_t CPerBlock,
index_t BPerThread, index_t BPerThread,
index_t KPerThread, index_t KPerThread,
index_t GemmThreadPerColumnPerCluster,
index_t GemmThreadPerRowPerCluster,
index_t GemmMPerThreadSubC, index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC, index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
...@@ -28,17 +26,9 @@ template <index_t GridSize, ...@@ -28,17 +26,9 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead>
__global__ void __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#if 0
__launch_bounds__(256,2)
#endif
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) Float* const __restrict__ p_out_global)
...@@ -115,57 +105,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -115,57 +105,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
} }
#endif #endif
// blockwise in copy // blockwise in copy
// formmat is [CPerBlock,BPerBlock + BGhostRead] // formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(in_cb_global_desc), decltype(in_cb_global_desc),
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()), decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{}; InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock] // format is [CPerBlock*Y*X,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(wei_ek_global_desc), decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc), decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()), decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{}; WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM // a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -182,19 +138,6 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -182,19 +138,6 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
constexpr auto c_kxb_thread_mtx_desc = constexpr auto c_kxb_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{}); make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
#if 0
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
true,
false,
false,
GemmKPerThreadLoop,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
#else
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize, BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc), decltype(a_cxk_block_mtx_desc),
...@@ -207,7 +150,6 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -207,7 +150,6 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop>{}; GemmKPerThreadLoop>{};
#endif
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_size = constexpr index_t in_block_size =
...@@ -216,9 +158,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -216,9 +158,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
constexpr index_t wei_block_size = constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{}); wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead constexpr index_t max_align =
? InBlockCopyDataPerRead mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
: WeiBlockCopyDataPerRead;
// LDS double buffer // LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment