Commit 1bd880a6 authored by Chao Liu

refactor

parent 796f72e2
......@@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
#if 1
// for 3x3, 34x34
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
......@@ -105,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
......@@ -145,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
// 3x3 58x58
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
......@@ -166,21 +168,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 1
// for 7x7, 38x38
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
......@@ -210,9 +197,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 56x56
// for 3x3, 56x56, v1r2, Pascal
// for 3x3, 34x34, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
......@@ -231,6 +251,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
......@@ -317,7 +339,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
#if 1
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
......@@ -346,6 +368,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW,
......
......@@ -205,6 +205,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
......@@ -233,6 +235,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
......@@ -289,6 +293,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
......
......@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -454,7 +454,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 7x7, 38x38
constexpr index_t N = 64;
constexpr index_t C = 256;
......
......@@ -16,7 +16,9 @@ template <index_t BlockSize,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
index_t BatchPerThread,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
index_t mMyThreadOffsetA = 0;
......@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
// copy B-sub to form B
......@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// loop over batch
......@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
}
......@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
}
......
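Note: the new DataPerReadA / DataPerReadB template parameters added above are threaded from the gridwise kernels down to these blockwise GEMMs and then handed to threadwise_matrix_copy as a Number<> tag, so the vector width is fixed at compile time. The following is a minimal, hypothetical sketch of that tag-dispatch pattern; Number and copy_row here are simplified stand-ins for illustration, not the library's actual definitions.

// Hypothetical, simplified illustration of passing a compile-time vector
// width through a Number<> tag (stand-in types, not the library's own).
template <int N>
struct Number
{
    static constexpr int value = N;
};

template <class Float, int DataPerRead>
void copy_row(const Float* src, Float* dst, int ncol, Number<DataPerRead>)
{
    // DataPerRead is a compile-time constant, so the loop step (and any
    // reinterpret_cast to a wider vector type) can be resolved statically.
    for(int j = 0; j < ncol; j += DataPerRead)
    {
        for(int k = 0; k < DataPerRead; ++k)
        {
            dst[j + k] = src[j + k];
        }
    }
}

// usage: copy_row(p_src, p_dst, 64, Number<4>{});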
......@@ -14,7 +14,9 @@ template <index_t BlockSize,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
......@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
......@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
......@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
......@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
bool even_loop = true;
......@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
......@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
......
......@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
......@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
......
......@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
......@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
......
......@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
......@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
......@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
#if 1
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
......
......@@ -26,6 +26,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
......@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =
......
......@@ -27,6 +27,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
......@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =
......
#pragma once
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol, index_t DataPerRead>
__device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
Float* __restrict__ p_dst,
Sequence<NRow, NCol>)
Sequence<NRow, NCol>,
Number<DataPerRead>)
{
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % DataPerRead == 0");
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
for(index_t j = 0; j < NCol; j += DataPerRead)
{
const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
p_dst[dst_index] = p_src[src_index];
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}
......
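The change above is the core of the refactor: instead of copying one element per iteration, threadwise_matrix_copy now reinterprets DataPerRead consecutive elements as a single vector and moves them in one transaction. Below is a standalone sketch of that idea with a fixed width of 4, where float4_t is a hand-rolled stand-in for vector_type<Float, DataPerRead>::MemoryType; names and strides here are illustrative only, not the library's API.

// Standalone sketch of the vectorized copy, assuming row-major src/dst with
// the given strides and NCol divisible by the vector width.
#include <cstddef>

struct float4_t
{
    float x, y, z, w;
};

template <std::size_t NRow, std::size_t NCol>
void copy_vectorized(const float* src, std::size_t src_stride, float* dst, std::size_t dst_stride)
{
    static_assert(NCol % 4 == 0, "NCol must be a multiple of the vector width");

    for(std::size_t i = 0; i < NRow; ++i)
    {
        // one 16-byte load/store per 4 floats instead of 4 scalar accesses
        for(std::size_t j = 0; j < NCol; j += 4)
        {
            *reinterpret_cast<float4_t*>(&dst[i * dst_stride + j]) =
                *reinterpret_cast<const float4_t*>(&src[i * src_stride + j]);
        }
    }
}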