Commit 1bd880a6 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 796f72e2
...@@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0 #if 1
// for 3x3, 34x34 // for 3x3, 34x34
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
...@@ -105,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -105,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 2; constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t OutThreadCopyDataPerWrite = 2;
...@@ -145,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -145,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
// 3x3 58x58, NKC = 64, 64, 256 // 3x3 58x58
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4; constexpr index_t CPerBlock = 4;
...@@ -166,21 +168,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -166,21 +168,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 1
// for 7x7, 38x38 // for 7x7, 38x38
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -210,9 +197,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -210,9 +197,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1, Pacal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 1
// for 3x3, 56x56 // for 3x3, 56x56, v1r2, Pascal
// for 3x3, 34x34, v1r2, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8; constexpr index_t CPerBlock = 8;
...@@ -231,6 +251,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -231,6 +251,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2; constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4;
...@@ -317,7 +339,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -317,7 +339,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 #if 1
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
#elif 1 #elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...@@ -346,6 +368,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -346,6 +368,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
Sequence<InBlockCopy_ThreadPerDimC, Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH, InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW, InBlockCopy_ThreadPerDimW,
......
...@@ -205,6 +205,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -205,6 +205,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4; constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16; constexpr index_t InBlockCopyThreadPerDim1 = 16;
...@@ -233,6 +235,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -233,6 +235,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4; constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16; constexpr index_t InBlockCopyThreadPerDim1 = 16;
...@@ -289,6 +293,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -289,6 +293,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopyThreadPerDim0, InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1, InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim0,
......
...@@ -409,7 +409,7 @@ int main(int argc, char* argv[]) ...@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -454,7 +454,7 @@ int main(int argc, char* argv[]) ...@@ -454,7 +454,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 7x7, 38x38 // 7x7, 38x38
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
......
...@@ -16,7 +16,9 @@ template <index_t BlockSize, ...@@ -16,7 +16,9 @@ template <index_t BlockSize,
index_t MLevel1Cluster, index_t MLevel1Cluster,
index_t NLevel1Cluster, index_t NLevel1Cluster,
index_t KPerThreadLoop, index_t KPerThreadLoop,
index_t BatchPerThread> index_t BatchPerThread,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{ {
index_t mMyThreadOffsetA = 0; index_t mMyThreadOffsetA = 0;
...@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetA, mMyThreadOffsetA,
a_thread_mtx, a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
} }
// copy B-sub to form B // copy B-sub to form B
...@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetB, mMyThreadOffsetB,
b_thread_mtx, b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
} }
// loop over batch // loop over batch
...@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx, a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
} }
} }
...@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx, b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
} }
} }
} }
......
...@@ -14,7 +14,9 @@ template <index_t BlockSize, ...@@ -14,7 +14,9 @@ template <index_t BlockSize,
index_t NLevel0Cluster, index_t NLevel0Cluster,
index_t MLevel1Cluster, index_t MLevel1Cluster,
index_t NLevel1Cluster, index_t NLevel1Cluster,
index_t KPerThreadLoop> index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{ {
struct MatrixIndex struct MatrixIndex
...@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetA, mMyThreadOffsetA,
a_thread_mtx, a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
} }
#pragma unroll #pragma unroll
...@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB, mMyThreadOffsetB,
b_thread_mtx, b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
} }
// C = A * B // C = A * B
...@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx, a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC, p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
} }
#pragma unroll #pragma unroll
...@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx, b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC, p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
} }
bool even_loop = true; bool even_loop = true;
...@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
m_repeat * MPerLevel1Cluster, m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx, a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC, p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
} }
#pragma unroll #pragma unroll
...@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat * NPerLevel1Cluster, n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx, b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC, p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
} }
// C = A * B // C = A * B
......
...@@ -30,6 +30,8 @@ template <index_t GridSize, ...@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims, class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
...@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
HoPerThread>{}; HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
......
...@@ -30,6 +30,8 @@ template <index_t GridSize, ...@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims, class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
...@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer ...@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
HoPerThread>{}; HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
......
...@@ -30,6 +30,8 @@ template <index_t GridSize, ...@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims, class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
...@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop, GemmKPerThreadLoop,
HoPerThread>{}; HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
...@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// register // register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
#if 1 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
......
...@@ -26,6 +26,8 @@ template <index_t GridSize, ...@@ -26,6 +26,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0, index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1, index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim0,
...@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop>{}; GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t max_align = constexpr index_t max_align =
......
...@@ -27,6 +27,8 @@ template <index_t GridSize, ...@@ -27,6 +27,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0, index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1, index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim0,
...@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop>{}; GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t max_align = constexpr index_t max_align =
......
#pragma once #pragma once
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol> template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol, index_t DataPerRead>
__device__ void threadwise_matrix_copy(SrcMatrix, __device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
DstMatrix, DstMatrix,
Float* __restrict__ p_dst, Float* __restrict__ p_dst,
Sequence<NRow, NCol>) Sequence<NRow, NCol>,
Number<DataPerRead>)
{ {
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr auto src_mtx = SrcMatrix{}; constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{}; constexpr auto dst_mtx = DstMatrix{};
for(index_t i = 0; i < NRow; ++i) for(index_t i = 0; i < NRow; ++i)
{ {
for(index_t j = 0; j < NCol; ++j) for(index_t j = 0; j < NCol; j += DataPerRead)
{ {
const index_t src_index = src_mtx.Get1dIndex(i, j); const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j); const index_t dst_index = dst_mtx.Get1dIndex(i, j);
p_dst[dst_index] = p_src[src_index]; *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment