Commit c9fa46af authored by Chao Liu

debugging implicit gemm v1: use 10d tensor output

parent 90abf427
@@ -248,16 +248,15 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t BlockSize = 128;
 #elif 1
-    // for 1x1, 14x14
+    // for 1x1, 14x14, Pascal
     constexpr index_t NPerBlock  = 16;
     constexpr index_t KPerBlock  = 128;
     constexpr index_t CPerBlock  = 8;
     constexpr index_t HoPerBlock = 2;
     constexpr index_t WoPerBlock = 2;
-    constexpr index_t NPerThread = 4;
-    constexpr index_t KPerThread = 16;
-    constexpr index_t CPerThread = 1;
+    constexpr index_t NPerThread = 8;
+    constexpr index_t KPerThread = 8;
     constexpr index_t HoPerThread = 1;
     constexpr index_t WoPerThread = 1;
@@ -265,8 +264,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
     constexpr index_t GemmKPerThreadLoop = 1;
     constexpr index_t InBlockCopy_ThreadPerDimC = 8;
@@ -278,6 +277,37 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t WeiBlockCopyDataPerRead   = 4;
     constexpr index_t OutThreadCopyDataPerWrite = 2;

+    constexpr index_t BlockSize = 128;
+#elif 1
+    // for 1x1, 14x14, Pascal, try
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 1;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead    = 4;
+    constexpr index_t WeiBlockCopyDataPerRead   = 4;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;
+
     constexpr index_t BlockSize = 128;
 #endif
...
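For reference only (not part of the diff): per the comment removed later in this commit, the v1 kernel computes the GEMM trans([C,K]) * [C, Wo*N], with Ho handled by the batched blockwise GEMM, so the block tuning constants above correspond roughly to the per-block GEMM tile sketched below. Names are local to this sketch and the values are taken from the new "1x1, 14x14, Pascal, try" branch.

    // Sketch, assuming the trans([C,K]) * [C,Wo*N] view of implicit GEMM v1.
    constexpr int KPerBlock  = 128; // rows of trans([C,K]) owned by one block -> GEMM M
    constexpr int WoPerBlock = 4;
    constexpr int NPerBlock  = 16;  // WoPerBlock * NPerBlock output columns   -> GEMM N
    constexpr int HoPerBlock = 1;   // treated as the batch dimension of the batched GEMM
    constexpr int CPerBlock  = 8;   // channels consumed per main-loop step    -> GEMM K slice

    constexpr int GemmMPerBlock = KPerBlock;              // 128
    constexpr int GemmNPerBlock = WoPerBlock * NPerBlock; // 64
    constexpr int GemmKPerStep  = CPerBlock;              // 8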
@@ -69,7 +69,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

-#if 0
+#if 1
     // 3x3, 34x34
     // need to use register double buffer for GEMM
     constexpr index_t BPerBlock = 128;
@@ -87,9 +87,6 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t GemmNLevel1Cluster = 8;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmThreadPerColumnPerCluster = 8;
-    constexpr index_t GemmThreadPerRowPerCluster    = 8;
     constexpr index_t InBlockCopyThreadPerDim0 = 4;
     constexpr index_t InBlockCopyThreadPerDim1 = 16;
@@ -98,6 +95,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t InBlockCopyDataPerRead  = 4;
     constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;

     constexpr index_t BlockSize = 128;
 #elif 0
...
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 1x1 filter, 14x14 image, C = 2048
     constexpr index_t N = 128;
     constexpr index_t C = 2048;
@@ -592,7 +592,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
     // 1x1 filter, 14x14 image, C = 512
     constexpr index_t N = 128;
     constexpr index_t C = 512;
@@ -663,7 +663,7 @@ int main(int argc, char* argv[])
     device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
 #elif 1
     device_implicit_gemm_convolution_1_chwn_cyxk_khwn
-#elif 0
+#elif 1
     device_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #endif
     (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
...
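For reference only (not part of the diff): every configuration toggled above uses HPad = WPad = 0, and assuming the unit filter stride used elsewhere in this driver, the output spatial size is the input size minus the filter size plus one, which is what the "3x3, 34x34" and "1x1, 14x14" labels imply.

    // Output size for a valid (unpadded) convolution with stride 1.
    constexpr int out_size(int in_size, int filter_size, int pad)
    {
        return in_size + 2 * pad - filter_size + 1;
    }

    static_assert(out_size(34, 3, 0) == 32, "3x3 filter on 34x34 input -> 32x32 output");
    static_assert(out_size(14, 1, 0) == 14, "1x1 filter on 14x14 input -> 14x14 output");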
@@ -44,6 +44,32 @@ __host__ __device__ constexpr auto
                     1>{};
 }

+// this is ugly, only for 10d
+template <index_t L0,
+          index_t L1,
+          index_t L2,
+          index_t L3,
+          index_t L4,
+          index_t L5,
+          index_t L6,
+          index_t L7,
+          index_t L8,
+          index_t L9>
+__host__ __device__ constexpr auto
+calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
+{
+    return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
+                    L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
+                    L3 * L4 * L5 * L6 * L7 * L8 * L9,
+                    L4 * L5 * L6 * L7 * L8 * L9,
+                    L5 * L6 * L7 * L8 * L9,
+                    L6 * L7 * L8 * L9,
+                    L7 * L8 * L9,
+                    L8 * L9,
+                    L9,
+                    1>{};
+}
+
 // this is ugly, only for 2d
 template <index_t L0, index_t L1, index_t Align>
 __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
...
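The new 10-dimensional calculate_default_strides overload follows the same pattern as the existing 2d/4d/8d ones: for a packed, row-major tensor the stride of each dimension is the product of all lengths to its right, and the last stride is 1. A generic, non-Sequence illustration of the same computation, purely for reference:

    #include <array>
    #include <cstddef>

    // Illustration only: packed row-major strides for an N-d tensor.
    // stride[i] = lengths[i+1] * lengths[i+2] * ... * lengths[N-1], stride[N-1] = 1.
    template <std::size_t N>
    std::array<std::size_t, N> packed_strides(const std::array<std::size_t, N>& lengths)
    {
        std::array<std::size_t, N> strides{};
        std::size_t running = 1;
        for(std::size_t i = N; i-- > 0;)
        {
            strides[i] = running;
            running *= lengths[i];
        }
        return strides;
    }

    // e.g. lengths {2, 3, 4, 5} give strides {60, 20, 5, 1}; the 10d overload above
    // expands exactly this product chain at compile time.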
@@ -340,8 +340,7 @@ struct BlockwiseChwnTensorCopyPadded
         constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

         const Float* p_src_tmp =
-            p_src +
-            src_desc.Get1dIndex(c_block_data_begin,
+            p_src + src_desc.Get1dIndex(c_block_data_begin,
                                 (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                                 (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                                 n_block_data_begin);
...
@@ -329,8 +329,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 {
                     threadwise_matrix_copy(
                         c_thread_sub_mtx,
-                        p_c_thread +
-                            c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
+                        p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                                         n_repeat * NPerLevel1Cluster),
                         c_block_mtx,
                         p_c_block +
...
@@ -93,8 +93,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     Float p_out_thread[out_thread_desc.GetElementSpace()];

     threadwise_4d_tensor_copy(out_block_desc,
-                              p_out_block +
-                                  out_block_desc.Get1dIndex(n_thread_data_begin,
+                              p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                             k_thread_data_begin,
                                                             ho_thread_data_begin,
                                                             wo_thread_data_begin),
@@ -108,8 +107,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
         // threadwise convolution
         threadwise_direct_convolution_2(
             in_thread_block_desc,
-            p_in_block +
-                in_block_desc.Get1dIndex(n_thread_data_begin,
+            p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
                                          c_thread_data_begin,
                                          hi_thread_data_begin,
                                          wi_thread_data_begin),
@@ -124,8 +122,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     threadwise_4d_tensor_copy(out_thread_desc,
                               p_out_thread,
                               out_block_desc,
-                              p_out_block +
-                                  out_block_desc.Get1dIndex(n_thread_data_begin,
+                              p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                             k_thread_data_begin,
                                                             ho_thread_data_begin,
                                                             wo_thread_data_begin),
...
@@ -40,12 +40,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
-        // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
-        // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
-        // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
+        // be careful of this assertion
         static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
-        static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
-                      "wrong!");

         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -172,16 +168,13 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
         constexpr index_t max_align =
             mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);

-        constexpr index_t in_block_space =
-            in_chwn_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});

         constexpr index_t wei_block_space =
             wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

-        __shared__ Float
-            p_in_block[in_block_space];
-        __shared__ Float
-            p_wei_block[wei_block_space];
+        __shared__ Float p_in_block[in_block_space];
+        __shared__ Float p_wei_block[wei_block_space];

         // register
         Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
@@ -190,8 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
         threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

         const Float* p_in_global_block_begin =
-            p_in_global +
-            in_chwn_global_desc.Get1dIndex(
+            p_in_global + in_chwn_global_desc.Get1dIndex(
                 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

         const Float* p_wei_global_block_begin =
@@ -269,26 +261,32 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
             c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

         // this is for v2 GEMM
-        // output is a 8d tensor
-        if(NPerThread < NPerBlock && WoPerThread == 1)
+        // output is a 10d tensor
+        if(NPerThread <= NPerBlock)
         {
-            constexpr index_t N1_ = GemmNPerThreadSubC;
-            constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
-            constexpr index_t K2_ = GemmMPerThreadSubC;
-            constexpr index_t K1_ = KPerBlock / KPerThread;
+            constexpr index_t N2 = GemmNPerThreadSubC;
+            constexpr index_t N1 = NPerBlock / N2;

-            constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});
+            constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
+            constexpr index_t W1 = WoPerBlock / W2;

-            constexpr auto out_8d_thread_desc =
-                make_ConstantTensorDescriptor(Sequence<KPerBlock / (K1_ * K2_),
-                                                       1,
-                                                       K2_,
-                                                       HoPerThread,
-                                                       WoPerBlock / W1_,
-                                                       1,
-                                                       1,
-                                                       N1_>{});
+            constexpr index_t K2 = GemmMPerThreadSubC;
+            constexpr index_t K1 = KPerBlock / KPerThread;
+
+            constexpr auto out_10d_global_desc =
+                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
+                                                       K1,
+                                                       K2,
+                                                       Ho,
+                                                       Wo / (W1 * W2),
+                                                       W1,
+                                                       W2,
+                                                       N / (N1 * N2),
+                                                       N1,
+                                                       N2>{});
+
+            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -301,25 +299,21 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
             }
 #endif

-            threadwise_8d_tensor_copy(
-                out_8d_thread_desc,
+            threadwise_10d_tensor_copy(
+                out_10d_thread_desc,
                 p_out_thread,
-                out_8d_global_desc,
+                out_10d_global_desc,
                 p_out_global +
                     out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                     ho_block_data_begin + ho_thread_data_begin,
                                                     wo_block_data_begin + wo_thread_data_begin,
                                                     n_block_data_begin + n_thread_data_begin),
-                out_8d_thread_desc.GetLengths(),
+                out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite>{});
         }
-        else if(NPerThread == NPerBlock)
-        {
-            // not implemented yet
-            assert(false);
-        }
         else
         {
+            // not implemented yet
             assert(false);
         }
 #endif
...
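To make the new 10d output descriptor easier to follow: K is split into (K/(K1*K2), K1, K2), Wo into (Wo/(W1*W2), W1, W2), and N into (N/(N1*N2), N1, N2), mirroring how the blockwise GEMM spreads the C matrix over thread clusters and sub-tiles; the 10d per-thread lengths then cover the same number of elements as the thread's KPerThread x HoPerThread x WoPerThread x NPerThread register tile. A small compile-time check with the values from the new "1x1, 14x14, Pascal, try" branch, for reference only (not part of the diff):

    // Values copied from the "1x1, 14x14, Pascal, try" tuning branch.
    constexpr int NPerBlock = 16, WoPerBlock = 4;
    constexpr int NPerThread = 4, KPerThread = 8, HoPerThread = 1, WoPerThread = 2;
    constexpr int GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    constexpr int GemmNLevel0Cluster = 2, GemmNLevel1Cluster = 4;

    constexpr int N2 = GemmNPerThreadSubC;                                                           // 4
    constexpr int W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); // 2
    constexpr int W1 = WoPerBlock / W2;                                                              // 2
    constexpr int K2 = GemmMPerThreadSubC;                                                           // 4

    // Product of the 10d per-thread lengths <KPerThread/K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>
    // equals the thread's register tile KPerThread * HoPerThread * WoPerThread * NPerThread.
    static_assert((KPerThread / K2) * K2 * HoPerThread * W1 * N2 ==
                      KPerThread * HoPerThread * WoPerThread * NPerThread,
                  "10d per-thread tile must cover the thread's output tile");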
@@ -31,11 +31,10 @@ template <index_t GridSize,
           index_t WeiBlockCopyThreadPerDim0,
           index_t WeiBlockCopyThreadPerDim1,
           index_t InBlockCopyDataPerRead,
-          index_t WeiBlockCopyDataPerRead>
+          index_t WeiBlockCopyDataPerRead,
+          index_t OutThreadCopyDataPerWrite>
 struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
 {
-    __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn() {}
-
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
@@ -232,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
             {
                 for(index_t x = 0; x < X; ++x)
                 {
-#if 0
+#if 1
                     blockwise_gemm.Run
 #elif 0
                     blockwise_gemm.Run_RegisterDoubleBuffer
...
@@ -387,12 +387,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
             constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});

-            threadwise_6d_tensor_copy(
-                out_6d_thread_desc,
+            threadwise_6d_tensor_copy(out_6d_thread_desc,
                                       p_out_thread,
                                       out_6d_global_desc,
-                p_out_global +
-                    out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
+                                      p_out_global + out_kb_global_desc.Get1dIndex(
+                                                         k_thread_data_begin, b_thread_data_begin),
                                       out_6d_thread_desc.GetLengths(),
                                       Number<OutThreadCopyDataPerWrite>{});
         }
...
@@ -113,8 +113,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
             c_block_work_begin += CPerBlock)
         {
             // copy input tensor to LDS
-            blockwise_in_copy.Run(p_in_global +
-                                      in_global_desc.Get1dIndex(n_block_work_begin,
+            blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
                                                                 c_block_work_begin,
                                                                 hi_block_work_begin,
                                                                 wi_block_work_begin),
@@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
         }

         // copy output tensor from LDS to device mem
-        blockwise_out_copy.Run(
-            p_out_block,
-            p_out_global +
-                out_global_desc.Get1dIndex(
-                    n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
+        blockwise_out_copy.Run(p_out_block,
+                               p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
+                                                                         k_block_work_begin,
+                                                                         ho_block_work_begin,
+                                                                         wo_block_work_begin));
     }
@@ -175,17 +175,15 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
             c_block_data_begin += CPerBlock, __syncthreads())
         {
             // copy input tensor to LDS
-            blockwise_in_copy.Run(p_in_global +
-                                      in_nchw_global_desc.Get1dIndex(n_block_data_begin,
+            blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
                                                                      c_block_data_begin,
                                                                      hi_block_data_begin,
                                                                      wi_block_data_begin),
                                   p_in_block);

             // copy weight tensor to LDS
-            blockwise_wei_copy.Run(
-                p_wei_global +
-                wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
+            blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
+                                                      k_block_data_begin, c_block_data_begin, 0, 0),
                                    p_wei_block);

             __syncthreads();
@@ -196,8 +194,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
 #if 1
                 threadwise_direct_convolution_2(
                     in_nchw_thread_block_desc,
-                    p_in_block +
-                        in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
+                    p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
                                                       c_thread_data,
                                                       hi_thread_data_begin,
                                                       wi_thread_data_begin),
@@ -209,8 +206,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
 #elif 0
                 threadwise_direct_convolution_3(
                     in_nchw_thread_block_desc,
-                    p_in_block +
-                        in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
+                    p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
                                                       c_thread_data,
                                                       hi_thread_data_begin,
                                                       wi_thread_data_begin),
@@ -228,8 +224,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
                 out_nkhw_thread_desc,
                 p_out_thread,
                 out_nkhw_global_desc,
-                p_out_global +
-                    out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                     k_block_data_begin + k_thread_data_begin,
                                                     ho_block_data_begin + ho_thread_data_begin,
                                                     wo_block_data_begin + wo_thread_data_begin),
...
@@ -198,9 +198,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                 p_in_vec_block);

             // copy weight tensor to LDS
-            blockwise_wei_copy.Run(
-                p_wei_vec_global +
-                wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
+            blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
+                                                          k_block_data_begin, c_block_data_begin, 0, 0),
                                    p_wei_vec_block);

             __syncthreads();
@@ -211,8 +210,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #if 1
                 threadwise_direct_convolution_2(
                     in_nchw_vec_thread_block_desc,
-                    p_in_vec_block +
-                        in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+                    p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                           c_thread_data,
                                                           hi_thread_data_begin,
                                                           wi_thread_data_begin),
@@ -224,8 +222,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #elif 0
                 threadwise_direct_convolution_3(
                     in_nchw_vec_thread_block_desc,
-                    p_in_vec_block +
-                        in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+                    p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                           c_thread_data,
                                                           hi_thread_data_begin,
                                                           wi_thread_data_begin),
@@ -243,8 +240,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                 out_nkhw_thread_desc,
                 p_out_thread,
                 out_nkhw_global_desc,
-                p_out_global +
-                    out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+                p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                     k_block_data_begin + k_thread_data_begin,
                                                     ho_block_data_begin + ho_thread_data_begin,
                                                     wo_block_data_begin + wo_thread_data_begin),
...
@@ -283,8 +283,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
         out_hkwn_thread_desc,
         p_out_thread,
         out_khwn_global_desc,
-        p_out_global +
-            out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+        p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                             ho_block_data_begin + ho_thread_data_begin,
                                             wo_block_data_begin + wo_thread_data_begin,
                                             n_block_data_begin + n_thread_data_begin),
...
@@ -22,7 +22,8 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
     return os;
 }

-typedef enum {
+typedef enum
+{
     Half  = 0,
     Float = 1,
 } DataType_t;
...
@@ -162,3 +162,118 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
         }
     }
 }
+
+// need to assume src and dst is aligned
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
+__device__ void threadwise_10d_tensor_copy(SrcDesc,
+                                           const Float* __restrict__ p_src,
+                                           DstDesc,
+                                           Float* __restrict__ p_dst,
+                                           SrcOpLengths,
+                                           Number<DataPerRead>)
+{
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
+    static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
+                      SrcOpLengths::nDim == 10,
+                  "wrong! should be 10 dimension");
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+    constexpr auto I6 = Number<6>{};
+    constexpr auto I7 = Number<7>{};
+    constexpr auto I8 = Number<8>{};
+    constexpr auto I9 = Number<9>{};
+
+    constexpr auto src_desc = SrcDesc{};
+    constexpr auto dst_desc = DstDesc{};
+    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
+    static_assert(SrcDesc{}.GetStride(I9) == 1 && DstDesc{}.GetStride(I9) == 1,
+                  "wrong! only support stride9 == 1!\n");
+
+    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                  "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+    static_assert(SrcDesc{}.GetStride(I8) % DataPerRead == 0 &&
+                      DstDesc{}.GetStride(I8) % DataPerRead == 0,
+                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
+    constexpr index_t L9 = SrcOpLengths{}.Get(I9);
+
+    static_assert(L9 % DataPerRead == 0, "wrong! L9 should be evenly divided by DataPerRead");
+
+    constexpr index_t nloop_d9 = L9 / DataPerRead;
+
+#pragma unroll
+    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
+    {
+#pragma unroll
+        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
+        {
+#pragma unroll
+            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
+            {
+#pragma unroll
+                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
+                {
+#pragma unroll
+                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
+                    {
+#pragma unroll
+                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
+                        {
+#pragma unroll
+                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
+                            {
+#pragma unroll
+                                for(index_t did7 = 0; did7 < ref_desc.GetLength(I7); ++did7)
+                                {
+#pragma unroll
+                                    for(index_t did8 = 0; did8 < ref_desc.GetLength(I8); ++did8)
+                                    {
+#pragma unroll
+                                        for(index_t iloop_d9 = 0; iloop_d9 < nloop_d9; ++iloop_d9)
+                                        {
+                                            const index_t src_index =
+                                                src_desc.Get1dIndex(did0,
+                                                                    did1,
+                                                                    did2,
+                                                                    did3,
+                                                                    did4,
+                                                                    did5,
+                                                                    did6,
+                                                                    did7,
+                                                                    did8,
+                                                                    iloop_d9 * DataPerRead);
+
+                                            const index_t dst_index =
+                                                dst_desc.Get1dIndex(did0,
+                                                                    did1,
+                                                                    did2,
+                                                                    did3,
+                                                                    did4,
+                                                                    did5,
+                                                                    did6,
+                                                                    did7,
+                                                                    did8,
+                                                                    iloop_d9 * DataPerRead);
+
+                                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
+                                                *(reinterpret_cast<const vector_t*>(p_src + src_index));
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}