Commit 6d066ede authored by Chao Liu

added implicit gemm v1r3, refactored decomposition of wei tensor (loop over y, x first, and C second) to allow easy lds double buffer on C
parent 5ce19234
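The loop-order change described above, in a minimal sketch (illustration only, not part of the diff): with the spatial (y, x) loops outermost, the innermost loop becomes a uniform reduction over C, so every iteration has the same copy-then-GEMM shape and the LDS tiles can be ping-pong buffered across C-tiles.

    for(index_t y = 0; y < Y; ++y)
        for(index_t x = 0; x < X; ++x)
            for(index_t c = 0; c < C; c += CPerBlock)
            {
                // copy the CPerBlock-thick slices of in(y, x) and wei(y, x) into LDS,
                // then accumulate one blockwise GEMM into the thread's registers;
                // every iteration is identical, which is what makes double buffering
                // over the C loop straightforward
            }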
@@ -3,8 +3,10 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
@@ -94,9 +96,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -108,10 +110,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
@@ -128,9 +130,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -142,7 +144,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
@@ -172,14 +174,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead_N = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t BlockSize = 256;
#elif 0
// for 3x3, 56x56, v1, Pascal
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
@@ -195,9 +197,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -207,7 +209,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
@@ -237,13 +239,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
@@ -260,9 +262,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -274,11 +276,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -290,13 +294,37 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
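// note: this Sequence replaces the four InBlockCopy_ThreadPerDim* scalars above;
// here its product 4 * 2 * 4 * 4 = 128 matches BlockSize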
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -308,11 +336,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t BlockSize = 128;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 1x1, 28x28
// for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -329,9 +360,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -341,11 +372,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14, Pascal
#elif 0
// for 1x1, 14x14, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -369,10 +400,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#endif
@@ -386,12 +417,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 1
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
#elif 1
GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
#endif
<GridSize,
BlockSize,
@@ -417,13 +452,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW,
InBlockCopy_ThreadPerDimN>,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{};
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_N>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),
@@ -87,13 +87,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>;
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;
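// note (assumption about Blockwise4dTensorCopyReorder3's semantics): each thread
// copies a SrcSubLengths sub-tensor and the thread cluster tiles SrcClusterLengths,
// so the per-pass NCHW footprint is <4,1,1,2> * <4,8,2,2> = <16, 8, 2, 4> and the
// cluster uses 4 * 8 * 2 * 2 = 128 threads = BlockSize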
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>;
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -137,10 +137,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_C,
OutThreadCopyDataPerWrite_N>{};
@@ -451,60 +451,6 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 1
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
@@ -578,31 +524,19 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 2048
// 3x3 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
@@ -673,9 +607,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
@@ -36,7 +36,7 @@ template <index_t GridSize,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
struct GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -213,6 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
#if 1
const Float* p_in_global_block_offset =
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
@@ -247,6 +248,43 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
__syncthreads();
}
}
#else
// this uses many more registers; haven't figured out why
for(index_t y = 0; y < Y; ++y)
{
const Float* p_in_global_block_offset =
p_in_global +
in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset +=
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
__syncthreads();
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
p_in_block +
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
p_out_thread);
}
__syncthreads();
}
}
#endif
// output: copy from register to global memory
const auto c_thread_mtx_begin =
@@ -35,10 +35,10 @@ template <index_t GridSize,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_KXC,
class WeiBlockCopyClusterLengths_CXK,
index_t WeiBlockCopyDataPerRead_C,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
@@ -122,7 +122,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
// blockwise copy
// input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
const auto blockwise_in_copy_reorder =
Blockwise4dTensorCopyReorder3<BlockSize,
Float,
@@ -132,7 +133,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
decltype(map_chwn2nchw),
InBlockReorderMapThreadCluster2SrcCluster,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N>{};
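// note on the permutation Sequences: entry i gives the source (NCHW) position of
// destination axis i, e.g. CHWN -> NCHW is Sequence<1, 2, 3, 0> (C = dim 1,
// Hi = dim 2, Wi = dim 3, N = dim 0), and the thread-cluster map CHNW -> NCHW
// is Sequence<1, 2, 0, 3>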
@@ -144,7 +145,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
decltype(wei_c_x_k_global_desc),
decltype(wei_c_x_k_block_desc),
decltype(wei_c_x_k_block_desc.GetLengths()),
Sequence<4, 1, 32>,
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_C>{};
// a series of blockwise batched GEMM
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
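// note: the divisions above decode the 1-d block id as a mixed-radix number,
// K-major: id = ((k_id * HBlockWork + h_id) * WBlockWork + w_id) * NBlockWork + n_id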
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// global tensor view
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
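// note: with strides <Y*X*K, 1>, element (c, k) of this 2-d view lands at
// offset c * Y*X*K + k, i.e. the (y, x) = (0, 0) slice of the [C, Y, X, K]
// weight tensor; advancing the base pointer by Get1dIndex(0, y, x, 0)
// re-centers the same view on filter tap (y, x)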
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead_N>{};
#else
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#endif
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
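// the batch dimension here is Ho: each of the HoPerBlock h-slices of in_block
// is an independent B-matrix [C, Wo*N] multiplied by the same A-matrix
// wei_block[C, K] (the per-batch stride 0 passed below keeps A fixed)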
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
#if 1
const Float* p_in_global_block_offset =
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
blockwise_in_copy.Run(p_in_global_block_offset +
in_c_h_w_n_global_desc.Get1dIndex(0, y, x, 0),
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0),
p_wei_block);
__syncthreads();
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
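// note: the #if branch above keeps the C loop outermost; the #else branch below
// hoists (y, x) outside C, which is the decomposition the lds_double_buffer
// variant builds on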
#else
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
const Float* p_in_global_block_offset =
p_in_global +
in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset +=
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
__syncthreads();
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
#endif
// output: copy from register to global memory
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
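// worked example with the active v1r3 config (NPerBlock = 16, KPerBlock = 128,
// KPerThread = 8, GemmMPerThreadSubC = GemmNPerThreadSubC = 4):
// N2 = 4, N1 = 16 / 4 = 4, K2 = 4, K1 = 128 / 8 = 16; W2 and W1 depend on the
// GemmNLevel0Cluster and GemmNLevel1Cluster values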
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}
};
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// global tensor view
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead_N>{};
#else
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#endif
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
const Float* p_in_global_block_offset =
p_in_global +
in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset +=
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
__syncthreads();
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
// output: copy from register to global memory
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}
};
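The C-loop body above is structurally identical to the single-buffer v1r3 kernel; a minimal sketch of the LDS double-buffering pattern the struct's name refers to, with the buffer bookkeeping hand-written here as an illustration (not taken from this commit):

    // two halves per operand: while the GEMM consumes tile ibuf, the next
    // C-tile is copied into the other half
    __shared__ Float p_in_block_double[2 * in_block_space];
    __shared__ Float p_wei_block_double[2 * wei_block_space];

    // preload the first C-tile into half 0
    blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_double);
    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_double);

    index_t ibuf = 0;
    for(index_t c = 0; c < C; c += CPerBlock, ibuf ^= 1)
    {
        __syncthreads(); // tile ibuf is now fully in LDS

        if(c + CPerBlock < C) // prefetch the next tile into the idle half
        {
            p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

            blockwise_in_copy.Run(p_in_global_block_offset,
                                  p_in_block_double + (ibuf ^ 1) * in_block_space);
            blockwise_wei_copy.Run(p_wei_global_block_offset,
                                   p_wei_block_double + (ibuf ^ 1) * wei_block_space);
        }

        // GEMM on the current tile; the __syncthreads() at the top of the next
        // iteration guarantees the prefetch has finished before it is read
        blockwise_batch_gemm.Run(p_wei_block_double + ibuf * wei_block_space,
                                 p_in_block_double + ibuf * in_block_space,
                                 p_out_thread);
    }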