Commit 766b0a9e authored by Chao Liu's avatar Chao Liu
Browse files

experimenting

parent f35c64eb
......@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
......@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
#if 1
// 3x3, 34x34
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 16;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(gridwise_direct_convolution_1<T,
InDesc,
......
......@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
......@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
#if 1
// 3x3, 34x34, 128 thread
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3, 34x34, 128 thread, fp16
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
......
......@@ -10,10 +10,10 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const Tensor<TInWei>& wei_kcyx,
OutDesc,
Tensor<TOut>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
// this suppose in / wei data type is int8x4
constexpr unsigned NVector = 4;
constexpr index_t NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
......@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// vectorized input
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
......@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 16;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 16;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
......
......@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
......@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#if 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,
......
......@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
......@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
constexpr unsigned BlockSize = 8;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 1;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 1;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
constexpr index_t BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
......
......@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
......@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256;
#elif 1
// 1x1, 14x14, Vega 10
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
#if 1
......
......@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, unsigned long x) -> int { return init != (x % 2); })
[](bool init, index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
......@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr auto I3 = Number<3>{};
constexpr auto desc = TConstTensorDesc{};
std::initializer_list<unsigned> lengths = {
std::initializer_list<index_t> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
std::initializer_list<index_t> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
return TensorDescriptor(lengths, strides);
......@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads,
UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
......@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
......@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 84;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 64;
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 64;
constexpr index_t HI = 112;
constexpr index_t WI = 112;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 86;
constexpr index_t K = 512;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16;
constexpr unsigned C = 192;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
constexpr index_t N = 16;
constexpr index_t C = 192;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 32;
constexpr unsigned WI = 32;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 14x14 image
constexpr unsigned N = 128;
constexpr unsigned C = 2048;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128;
constexpr unsigned C = 512;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
auto lower_pads = Sequence<HPad, WPad>{};
......@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
index_t nrepeat = atoi(argv[2]);
if(do_verification)
{
......
#pragma once
template <class TData, unsigned NSize>
template <class TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize;
static constexpr index_t nSize = NSize;
unsigned mData[nSize];
index_t mData[nSize];
template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{
}
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
__host__ __device__ TData operator[](index_t i) const { return mData[i]; }
};
#pragma once
#include "common.hip.hpp"
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
__host__ __device__ constexpr ConstantMatrixDescriptor()
......@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
}
__host__ __device__ constexpr unsigned NRow() const { return NRow_; }
__host__ __device__ constexpr index_t NRow() const { return NRow_; }
__host__ __device__ constexpr unsigned NCol() const { return NCol_; }
__host__ __device__ constexpr index_t NCol() const { return NCol_; }
__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
__host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
__host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol;
#endif
}
template <unsigned SubNRow, unsigned SubNCol>
template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const
{
......@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
}
};
template <unsigned NRow, unsigned NCol>
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
template <unsigned NRow, unsigned NCol, unsigned RowStride>
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
......
......@@ -2,35 +2,35 @@
#include "common.hip.hpp"
// this is ugly, only for 2d
template <unsigned L0, unsigned L1>
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
return Sequence<L1, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}
// this is ugly, only for 8d
template <unsigned L0,
unsigned L1,
unsigned L2,
unsigned L3,
unsigned L4,
unsigned L5,
unsigned L6,
unsigned L7>
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
......@@ -45,20 +45,20 @@ __host__ __device__ constexpr auto
}
// this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align>
template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>)
{
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>)
{
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}
......@@ -66,27 +66,27 @@ template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim;
static constexpr index_t nDim = Lengths::nDim;
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
}
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetLength(Number<I>) const
{
return Lengths{}.Get(Number<I>{});
}
template <unsigned I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetStride(Number<I>) const
{
return Strides{}.Get(Number<I>{});
}
......@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return Type{}.GetLength(idim);
}
};
__host__ __device__ constexpr unsigned GetElementSize() const
__host__ __device__ constexpr index_t GetElementSize() const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct multiply
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a * b;
}
......@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
}
};
template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a + b;
}
......@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
}
template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const
__host__ __device__ index_t Get1dIndex(Is... is) const
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<unsigned, nDim>(is...);
const auto multi_id = Array<index_t, nDim>(is...);
unsigned id = 0;
index_t id = 0;
static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get();
constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim);
#endif
});
return id;
......@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
template <unsigned IDim, unsigned NVector>
template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{
assert(false); // not implemented
......@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, unsigned Align>
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
......@@ -194,7 +198,7 @@ template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension();
constexpr index_t ndim = desc.GetDimension();
static_assert(ndim >= 2 && ndim <= 8, "wrong!");
......
......@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
template <unsigned... Is>
template <index_t... Is>
struct Sequence
{
using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is);
static constexpr index_t nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
const index_t mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t Get(Number<I>) const
{
return mData[I];
}
// this is ugly, only for nDIm = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
constexpr index_t NR0 = old_sequence.mData[I0];
constexpr index_t NR1 = old_sequence.mData[I1];
constexpr index_t NR2 = old_sequence.mData[I2];
constexpr index_t NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
......@@ -41,7 +41,7 @@ struct Sequence
assert(false);
}
template <unsigned I>
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
return Sequence<Is..., I>{};
......@@ -56,14 +56,14 @@ struct Sequence
}
};
template <unsigned... Is, unsigned I>
template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{};
}
template <class F, unsigned... Xs, unsigned... Ys>
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
......@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{};
}
template <unsigned... Xs, unsigned... Ys>
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
__host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{
return x + y;
}
......@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
}
template <unsigned... Is>
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
return sequence_pop_back(Type{});
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class InBlockDesc,
class WeiBlockDesc,
class OutBlockDesc,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread>
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block,
WeiBlockDesc,
......@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{};
constexpr unsigned Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3);
constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr index_t X = wei_block_desc.GetLength(I3);
constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1;
constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr index_t InTileSizeW = WoPerThread + X - 1;
// divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0
if(threadIdx.x == 0)
......@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;
for(unsigned thread_work_id = thread_id;
for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize)
{
unsigned itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
index_t itmp = thread_work_id;
index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
index_t y_thread_work_id = itmp / XThreadWork;
index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
index_t n_thread_data_begin = n_thread_work_id * NPerThread;
index_t k_thread_data_begin = k_thread_work_id * KPerThread;
index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;
unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding
index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding
Float p_out_thread[out_thread_desc.GetElementSpace()];
......@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread,
out_thread_desc.GetLengths());
for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread)
{
// threadwise convolution
......
This diff is collapsed.
......@@ -5,9 +5,9 @@
#include "Array.hip.hpp"
#include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2>
struct is_same
......@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
}
#endif
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{
return (a + b - 1) / b;
}
......@@ -11,3 +11,5 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"
#endif
using index_t = uint32_t;
......@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; }
};
template <unsigned N>
using Number = integral_constant<unsigned, N>;
template <index_t N>
using Number = integral_constant<index_t, N>;
#pragma once
#include "config.h"
template <class T, unsigned N>
template <class T, index_t N>
struct vector_type
{
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment