Commit 766b0a9e authored by Chao Liu

experimenting

parent f35c64eb
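
Editorial note: the change is mechanical throughout, every unsigned index is replaced with the project's index_t alias. A minimal sketch of the assumed alias, for illustration only; the real definition lives in the project's common headers and may differ:

// assumption: index_t is a signed 32-bit index type shared by host and device code
using index_t = int32_t;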
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
 #if 1
 // 3x3, 34x34
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 16;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 16;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(gridwise_direct_convolution_1<T,
 InDesc,
...
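
Worked example of the GridSize expression above, assuming the 3x3, 34x34 configuration from main() further below (N = 64, K = 64, Ho = Wo = 32) together with the block sizes in this hunk; the figure is illustrative, not taken from a recorded run:

// one block per (NPerBlock, KPerBlock, HoPerBlock, WoPerBlock) output tile
static_assert((64 / 2) * (64 / 16) * (32 / 4) * (32 / 32) == 1024, "expected grid size for the 3x3, 34x34 case");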
@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 #if 1
 // 3x3, 34x34, 128 thread
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3, 34x34, 128 thread, fp16
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time =
 launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
...
@@ -10,10 +10,10 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 const Tensor<TInWei>& wei_kcyx,
 OutDesc,
 Tensor<TOut>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 // this suppose in / wei data type is int8x4
-constexpr unsigned NVector = 4;
+constexpr index_t NVector = 4;
 using accum_t = int32_t;
 using vector_t = vector_type<TInWei, NVector>;
 using vector_mem_t = typename vector_t::MemoryType;
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // vectorized input
 auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 #if 0
 // 3x3, 34x34, 128 thread, fp32, vector = 1
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, fp32, vector = 2
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 1x1, 32x32, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 16;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 16;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
...
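
For the vectorized path above, the input and weight descriptors are rebuilt with the channel extent divided by NVector, since four int8 values travel as one int8x4 element. A small sketch of that bookkeeping under assumed sizes (C = 256, as in the 3x3, 34x34 case in main()):

// assumption: C is the logical channel count; NVector = 4 packs int8 into int8x4
constexpr int C_assumed = 256;
constexpr int NVector_assumed = 4;
static_assert(C_assumed % NVector_assumed == 0, "C must be divisible by the vector width");
static_assert(C_assumed / NVector_assumed == 64, "channel extent seen by the vectorized kernel");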
@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 #if 0
 // for 3x3, 34x34
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 1
 // for 1x1, 14x14
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,
...
@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 Tensor<T>& out_nkhw,
 LowerPads,
 UpperPads,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 #if 0
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 1;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 1;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
-constexpr unsigned BlockSize = 8;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 1;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 1;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
+constexpr index_t BlockSize = 8;
 #elif 1
 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3 56x56, NKC = 16,256,128, with padding
 // 3x3 28x28, NKC = 16,512,256, with padding
 // 3x3 20x84, NKC = 16,256,256, with padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5 filter, 20x84 image, 1x1 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 5x5 filter, 28x28 image, 2x2 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
...
@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned N = in_nchw_desc.GetLength(I0);
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
-constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
+constexpr index_t N = in_nchw_desc.GetLength(I0);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
 // convert in_nchw to in_cnhw
 auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
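
A quick check of the BGhostRead expression above, assuming a 3x3 filter (Y = X = 3) on a 34x34 input (Wi = 34); this only illustrates the arithmetic, not a measured value:

// extra trailing elements the implicit-GEMM window may touch beyond one block of the flattened input
static_assert((3 - 1) * 34 + (3 - 1) == 70, "BGhostRead for Y = X = 3, Wi = 34");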
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 #if 0
 // 3x3, 34x34
 // need to use register double buffer for GEMM
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 8;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 8;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 64 threads
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 64;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 64;
-#elif 1
+#elif 0
 // 1x1, 28x28, 128 threads, no lds-double-buffer
 // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 256 thread
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 256;
+#elif 1
+// 1x1, 14x14, Vega 10
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 4;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 256;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 #if 1
...
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
 template <class... Ts>
 double operator()(Ts... Xs) const
 {
-std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
+std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
 return std::accumulate(dims.begin(),
 dims.end(),
 true,
-[](bool init, unsigned long x) -> int { return init != (x % 2); })
+[](bool init, index_t x) -> int { return init != (x % 2); })
 ? 1
 : -1;
 }
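
The accumulate above folds the parity of every coordinate starting from true, so the generator writes +1 and -1 in a checkerboard pattern. An equivalent scalar sketch (hypothetical helper, not part of the diff):

#include <cstddef>

// returns +1 when an even number of coordinates is odd, -1 otherwise
inline int checkerboard_value(const std::size_t* idx, std::size_t rank)
{
    bool parity = true;
    for(std::size_t d = 0; d < rank; ++d)
        parity = (parity != static_cast<bool>(idx[d] % 2)); // flip once per odd coordinate
    return parity ? 1 : -1;
}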
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
 constexpr auto I3 = Number<3>{};
 constexpr auto desc = TConstTensorDesc{};
-std::initializer_list<unsigned> lengths = {
+std::initializer_list<index_t> lengths = {
 desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
-std::initializer_list<unsigned> strides = {
+std::initializer_list<index_t> strides = {
 desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
 return TensorDescriptor(lengths, strides);
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
 LowerPads,
 UpperPads)
 {
-unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
-unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+index_t w_pad_low = LowerPads{}.Get(Number<1>{});
+index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+index_t w_pad_up = UpperPads{}.Get(Number<1>{});
 auto f = [&](auto n, auto k, auto ho, auto wo) {
 double v = 0;
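
For reference, with unit stride the output extents implied by these pads follow directly from the padded input and the filter size. A hedged helper illustrating the relation (hypothetical, not part of the source):

// out = in + pad_low + pad_up - filter + 1 for a stride-1 convolution
constexpr int conv_out_len(int in, int pad_low, int pad_up, int filter)
{
    return in + pad_low + pad_up - filter + 1;
}
static_assert(conv_out_len(34, 0, 0, 3) == 32, "3x3 filter, 34x34 input, no padding");
static_assert(conv_out_len(56, 1, 1, 3) == 56, "3x3 filter, 56x56 input, 1x1 padding");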
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
 std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
 std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
-unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
-unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+index_t w_pad_low = LowerPads{}.Get(Number<1>{});
+index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+index_t w_pad_up = UpperPads{}.Get(Number<1>{});
 std::size_t HiPerTile = HoPerTile + Y - 1;
 std::size_t WiPerTile = WoPerTile + X - 1;
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
 #if 0
-constexpr unsigned N = 1;
-constexpr unsigned C = 1;
-constexpr unsigned HI = 28;
-constexpr unsigned WI = 28;
-constexpr unsigned K = 1;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 1;
+constexpr index_t C = 1;
+constexpr index_t HI = 28;
+constexpr index_t WI = 28;
+constexpr index_t K = 1;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 34x34
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 34;
-constexpr unsigned WI = 34;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 34;
+constexpr index_t WI = 34;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 56x56
-constexpr unsigned N = 64;
-constexpr unsigned C = 64;
-constexpr unsigned HI = 56;
-constexpr unsigned WI = 56;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 64;
+constexpr index_t C = 64;
+constexpr index_t HI = 56;
+constexpr index_t WI = 56;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 3x3, 58x58
-constexpr unsigned N = 64;
-constexpr unsigned C = 64;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 64;
+constexpr index_t C = 64;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 5x5, 36x36
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 36;
-constexpr unsigned WI = 36;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 5;
-constexpr unsigned X = 5;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 36;
+constexpr index_t WI = 36;
+constexpr index_t K = 64;
+constexpr index_t Y = 5;
+constexpr index_t X = 5;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 7x7, 38x38
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 38;
-constexpr unsigned WI = 38;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 7;
-constexpr unsigned X = 7;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 38;
+constexpr index_t WI = 38;
+constexpr index_t K = 64;
+constexpr index_t Y = 7;
+constexpr index_t X = 7;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 58x58
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 3x3 filter, 58x58 image, 0x0 padding
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3 filter, 56x56 image, 1x1 padding
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 56;
-constexpr unsigned WI = 56;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 1;
-constexpr unsigned WPad = 1;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 56;
+constexpr index_t WI = 56;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 1;
+constexpr index_t WPad = 1;
#elif 0 #elif 0
// 3x3 filter, 28x28 image, 1x1 padding // 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
// 3x3 filter, 20x84 image, 1x1 padding // 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 20; constexpr index_t HI = 20;
constexpr unsigned WI = 84; constexpr index_t WI = 84;
constexpr unsigned K = 256; constexpr index_t K = 256;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 3x3 filter, 112x112 image, 1x1 padding // 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 64; constexpr index_t C = 64;
constexpr unsigned HI = 112; constexpr index_t HI = 112;
constexpr unsigned WI = 112; constexpr index_t WI = 112;
constexpr unsigned K = 128; constexpr index_t K = 128;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 5x5 filter, 20x86 image, 1x1 padding // 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 20; constexpr index_t HI = 20;
constexpr unsigned WI = 86; constexpr index_t WI = 86;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 5; constexpr index_t Y = 5;
constexpr unsigned X = 5; constexpr index_t X = 5;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 5x5 filter, 28x28 image, 2x2 padding // 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 192; constexpr index_t C = 192;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 32; constexpr index_t K = 32;
constexpr unsigned Y = 5; constexpr index_t Y = 5;
constexpr unsigned X = 5; constexpr index_t X = 5;
constexpr unsigned HPad = 2; constexpr index_t HPad = 2;
constexpr unsigned WPad = 2; constexpr index_t WPad = 2;
#elif 0 #elif 0
// 1x1 filter, 32x32 image // 1x1 filter, 32x32 image
constexpr unsigned N = 64; constexpr index_t N = 64;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 32; constexpr index_t HI = 32;
constexpr unsigned WI = 32; constexpr index_t WI = 32;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image, C = 2048
constexpr unsigned N = 128; constexpr index_t N = 128;
constexpr unsigned C = 2048; constexpr index_t C = 2048;
constexpr unsigned HI = 14; constexpr index_t HI = 14;
constexpr unsigned WI = 14; constexpr index_t WI = 14;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 1
// 1x1 filter, 14x14 image, C = 512 // 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128; constexpr index_t N = 128;
constexpr unsigned C = 512; constexpr index_t C = 512;
constexpr unsigned HI = 14; constexpr index_t HI = 14;
constexpr unsigned WI = 14; constexpr index_t WI = 14;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#endif #endif
auto lower_pads = Sequence<HPad, WPad>{}; auto lower_pads = Sequence<HPad, WPad>{};
...@@ -634,7 +634,7 @@ int main(int argc, char* argv[]) ...@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
} }
bool do_verification = atoi(argv[1]); bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]); index_t nrepeat = atoi(argv[2]);
if(do_verification) if(do_verification)
{ {
......
#pragma once #pragma once
template <class TData, unsigned NSize> template <class TData, index_t NSize>
struct Array struct Array
{ {
using Type = Array<TData, NSize>; using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize; static constexpr index_t nSize = NSize;
unsigned mData[nSize]; index_t mData[nSize];
template <class... Xs> template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...} __host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{ {
} }
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; } __host__ __device__ TData operator[](index_t i) const { return mData[i]; }
}; };
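A minimal host-side usage sketch of this fixed-size Array (assuming the header above is included, index_t = uint32_t from config.h, and compilation with hipcc/nvcc so the __host__ __device__ qualifiers resolve). The surrounding code only instantiates it with TData = index_t, e.g. as the multi-index container in Get1dIndex; note the backing buffer is declared as index_t mData[nSize] rather than TData:

    // Sketch: pack a 4d multi-index and read it back.
    #include <cstdio>

    int main()
    {
        Array<index_t, 4> multi_id(2, 16, 4, 32); // constructor casts each argument to TData

        for(index_t i = 0; i < Array<index_t, 4>::nSize; ++i)
            printf("multi_id[%u] = %u\n", i, multi_id[i]);

        return 0;
    }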
#pragma once #pragma once
#include "common.hip.hpp" #include "common.hip.hpp"
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_> template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor struct ConstantMatrixDescriptor
{ {
__host__ __device__ constexpr ConstantMatrixDescriptor() __host__ __device__ constexpr ConstantMatrixDescriptor()
...@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor ...@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!"); static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
} }
__host__ __device__ constexpr unsigned NRow() const { return NRow_; } __host__ __device__ constexpr index_t NRow() const { return NRow_; }
__host__ __device__ constexpr unsigned NCol() const { return NCol_; } __host__ __device__ constexpr index_t NCol() const { return NCol_; }
__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; } __host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; } __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; } __host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; } __host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const __host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{ {
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol; return irow * RowStride_ + icol;
#endif
} }
template <unsigned SubNRow, unsigned SubNCol> template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const Number<SubNCol>) const
{ {
...@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor ...@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
} }
}; };
template <unsigned NRow, unsigned NCol> template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>) __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{ {
return ConstantMatrixDescriptor<NRow, NCol, NCol>{}; return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
} }
template <unsigned NRow, unsigned NCol, unsigned RowStride> template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>) make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{ {
......
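For orientation, a small usage sketch of the matrix descriptor above (values are illustrative): an 8x8 logical tile stored with a padded row stride of 16, where Get1dIndex maps (row, col) to row * RowStride + col (or the __mul24 variant under DEVICE_BACKEND_HIP):

    // Sketch: 8x8 tile padded to a row stride of 16.
    __host__ __device__ index_t example_matrix_offset()
    {
        constexpr auto a_desc = make_ConstantMatrixDescriptor(Number<8>{}, Number<8>{}, Number<16>{});

        static_assert(a_desc.GetElementSize() == 64, "8 x 8 logical elements");
        static_assert(a_desc.GetElementSpace() == 128, "8 rows x row stride 16");

        // (3, 5) -> 3 * 16 + 5 = 53
        return a_desc.Get1dIndex(3, 5);
    }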
...@@ -2,35 +2,35 @@ ...@@ -2,35 +2,35 @@
#include "common.hip.hpp" #include "common.hip.hpp"
// this is ugly, only for 2d // this is ugly, only for 2d
template <unsigned L0, unsigned L1> template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{ {
return Sequence<L1, 1>{}; return Sequence<L1, 1>{};
} }
// this is ugly, only for 4d // this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3> template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{ {
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{}; return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
} }
// this is ugly, only for 6d // this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5> template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{ {
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{}; return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
} }
// this is ugly, only for 8d // this is ugly, only for 8d
template <unsigned L0, template <index_t L0,
unsigned L1, index_t L1,
unsigned L2, index_t L2,
unsigned L3, index_t L3,
unsigned L4, index_t L4,
unsigned L5, index_t L5,
unsigned L6, index_t L6,
unsigned L7> index_t L7>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>) calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{ {
...@@ -45,20 +45,20 @@ __host__ __device__ constexpr auto ...@@ -45,20 +45,20 @@ __host__ __device__ constexpr auto
} }
// this is ugly, only for 2d // this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align> template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>, __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>) Number<Align>)
{ {
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align); constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{}; return Sequence<L1_align, 1>{};
} }
// this is ugly, only for 4d // this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align> template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>, __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>) Number<Align>)
{ {
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align); constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{}; return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
} }
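A worked example of the aligned-stride helper (lengths chosen for illustration): for lengths (2, 4, 4, 34) with Align = 4, the innermost length 34 rounds up to 36, so the padded default strides come out as (4 * 4 * 36, 4 * 36, 36, 1) = (576, 144, 36, 1):

    // Sketch: padded default strides for a 4d tensor.
    constexpr auto aligned_strides =
        calculate_default_strides_aligned(Sequence<2, 4, 4, 34>{}, Number<4>{});

    static_assert(aligned_strides.Get(Number<2>{}) == 36,  "34 rounded up to a multiple of 4");
    static_assert(aligned_strides.Get(Number<0>{}) == 576, "outer strides use the padded inner stride");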
...@@ -66,27 +66,27 @@ template <class Lengths, class Strides> ...@@ -66,27 +66,27 @@ template <class Lengths, class Strides>
struct ConstantTensorDescriptor struct ConstantTensorDescriptor
{ {
using Type = ConstantTensorDescriptor<Lengths, Strides>; using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim; static constexpr index_t nDim = Lengths::nDim;
__host__ __device__ constexpr ConstantTensorDescriptor() __host__ __device__ constexpr ConstantTensorDescriptor()
{ {
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent"); static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
} }
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; } __host__ __device__ constexpr index_t GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; } __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; } __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const __host__ __device__ constexpr index_t GetLength(Number<I>) const
{ {
return Lengths{}.Get(Number<I>{}); return Lengths{}.Get(Number<I>{});
} }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const __host__ __device__ constexpr index_t GetStride(Number<I>) const
{ {
return Strides{}.Get(Number<I>{}); return Strides{}.Get(Number<I>{});
} }
...@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor ...@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f struct GetElementSize_f
{ {
template <class IDim> template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const __host__ __device__ constexpr index_t operator()(IDim idim) const
{ {
return Type{}.GetLength(idim); return Type{}.GetLength(idim);
} }
}; };
__host__ __device__ constexpr unsigned GetElementSize() const __host__ __device__ constexpr index_t GetElementSize() const
{ {
// c++14 doesn't support constexpr lambdas, has to use this trick instead // c++14 doesn't support constexpr lambdas, has to use this trick instead
struct multiply struct multiply
{ {
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{ {
return a * b; return a * b;
} }
...@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor ...@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f struct GetElementSpace_f
{ {
template <class IDim> template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const __host__ __device__ constexpr index_t operator()(IDim idim) const
{ {
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim); return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
} }
}; };
template <class Align = Number<1>> template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const __host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{ {
// c++14 doesn't support constexpr lambdas, has to use this trick instead // c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add struct add
{ {
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{ {
return a + b; return a + b;
} }
...@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor ...@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
} }
template <class... Is> template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const __host__ __device__ index_t Get1dIndex(Is... is) const
{ {
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<unsigned, nDim>(is...); const auto multi_id = Array<index_t, nDim>(is...);
unsigned id = 0; index_t id = 0;
static_loop_n<nDim>{}([&](auto IDim) { static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get(); constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim); id += multi_id[idim] * GetStride(IDim);
#endif
}); });
return id; return id;
...@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor ...@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{}; return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
} }
template <unsigned IDim, unsigned NVector> template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const __host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{ {
assert(false); // not implemented assert(false); // not implemented
...@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride ...@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{}; return ConstantTensorDescriptor<Lengths, Strides>{};
} }
template <class Lengths, unsigned Align> template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>) __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{ {
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{})); using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
...@@ -194,7 +198,7 @@ template <class TDesc> ...@@ -194,7 +198,7 @@ template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{ {
constexpr auto desc = TDesc{}; constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension(); constexpr index_t ndim = desc.GetDimension();
static_assert(ndim >= 2 && ndim <= 8, "wrong!"); static_assert(ndim >= 2 && ndim <= 8, "wrong!");
......
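A short usage sketch of the compile-time tensor descriptor (assuming the header above is included; lengths and strides are illustrative): packed strides for lengths (2, 3, 4, 5) are (60, 20, 5, 1), and Get1dIndex folds a multi-index into the corresponding linear offset:

    // Sketch: packed 4d descriptor and offset computation.
    using ExampleDesc =
        decltype(make_ConstantTensorDescriptor(Sequence<2, 3, 4, 5>{}, Sequence<60, 20, 5, 1>{}));

    __host__ __device__ index_t example_tensor_offset()
    {
        constexpr auto desc = ExampleDesc{};

        static_assert(desc.GetDimension() == 4, "4d descriptor");
        static_assert(desc.GetElementSize() == 120, "2 * 3 * 4 * 5 elements");

        // (1, 2, 3, 4) -> 1*60 + 2*20 + 3*5 + 4*1 = 119
        return desc.Get1dIndex(1, 2, 3, 4);
    }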
...@@ -2,38 +2,38 @@ ...@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp" #include "constant_integral.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
template <unsigned... Is> template <index_t... Is>
struct Sequence struct Sequence
{ {
using Type = Sequence<Is...>; using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is); static constexpr index_t nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...}; const index_t mData[nDim] = {Is...};
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned Get(Number<I>) const __host__ __device__ constexpr index_t Get(Number<I>) const
{ {
return mData[I]; return mData[I];
} }
// this is ugly, only for nDim = 4 // this is ugly, only for nDim = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3> template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{ {
static_assert(nDim == 4, "nDim != 4"); static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{}; constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0]; constexpr index_t NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1]; constexpr index_t NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2]; constexpr index_t NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3]; constexpr index_t NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{}; return Sequence<NR0, NR1, NR2, NR3>{};
} }
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3> template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{ {
// don't know how to implement this // don't know how to implement this
...@@ -41,7 +41,7 @@ struct Sequence ...@@ -41,7 +41,7 @@ struct Sequence
assert(false); assert(false);
} }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const __host__ __device__ constexpr auto PushBack(Number<I>) const
{ {
return Sequence<Is..., I>{}; return Sequence<Is..., I>{};
...@@ -56,14 +56,14 @@ struct Sequence ...@@ -56,14 +56,14 @@ struct Sequence
} }
}; };
template <unsigned... Is, unsigned I> template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>) __host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{ {
static_assert(sizeof...(Is) >= 1, "empty Sequence!"); static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{}; return Sequence<Is...>{};
} }
template <class F, unsigned... Xs, unsigned... Ys> template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f) __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{ {
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same"); static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
...@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc ...@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{}; return Sequence<f(Xs, Ys)...>{};
} }
template <unsigned... Xs, unsigned... Ys> template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{ {
struct add struct add
{ {
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{ {
return x + y; return x + y;
} }
...@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen ...@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{}); return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
} }
template <unsigned... Is> template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{ {
return sequence_pop_back(Type{}); return sequence_pop_back(Type{});
......
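For reference, a small sketch of how Sequence is used elsewhere in the code (values are illustrative): Get reads one compile-time element, ReorderByGetNewFromOld builds a permuted sequence by picking the old position for each new slot, and PushBack appends an element:

    // Sketch: compile-time length manipulation with Sequence.
    constexpr auto nchw_lengths = Sequence<64, 256, 34, 34>{};

    static_assert(nchw_lengths.Get(Number<1>{}) == 256, "C comes second in NCHW");

    // Map (0, 2, 3, 1): new dim i takes old dim map[i], i.e. NCHW -> NHWC.
    constexpr auto nhwc_lengths = nchw_lengths.ReorderByGetNewFromOld(Sequence<0, 2, 3, 1>{});
    static_assert(nhwc_lengths.Get(Number<3>{}) == 256, "C moved to the last position");

    // Append one more length at compile time: Sequence<64, 256, 34, 34, 8>.
    constexpr auto extended_lengths = nchw_lengths.PushBack(Number<8>{});
    static_assert(decltype(extended_lengths)::nDim == 5, "one dimension appended");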
...@@ -3,16 +3,16 @@ ...@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp" #include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp" #include "threadwise_direct_convolution.hip.hpp"
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class InBlockDesc, class InBlockDesc,
class WeiBlockDesc, class WeiBlockDesc,
class OutBlockDesc, class OutBlockDesc,
unsigned NPerThread, index_t NPerThread,
unsigned KPerThread, index_t KPerThread,
unsigned CPerThread, index_t CPerThread,
unsigned HoPerThread, index_t HoPerThread,
unsigned WoPerThread> index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc, __device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block, Float* const __restrict__ p_in_block,
WeiBlockDesc, WeiBlockDesc,
...@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{}; constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{}; constexpr auto out_block_desc = OutBlockDesc{};
constexpr unsigned Y = wei_block_desc.GetLength(I2); constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3); constexpr index_t X = wei_block_desc.GetLength(I3);
constexpr unsigned InTileSizeH = HoPerThread + Y - 1; constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1; constexpr index_t InTileSizeW = WoPerThread + X - 1;
// divide thread work // divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
...@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc = constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides()); make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const unsigned thread_id = threadIdx.x; const index_t thread_id = threadIdx.x;
for(unsigned thread_work_id = thread_id; for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork; thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize) thread_work_id += BlockSize)
{ {
unsigned itmp = thread_work_id; index_t itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork); index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork); itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork; index_t y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork; index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
unsigned n_thread_data_begin = n_thread_work_id * NPerThread; index_t n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread; index_t k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread; index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread; index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;
unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding
Float p_out_thread[out_thread_desc.GetElementSpace()]; Float p_out_thread[out_thread_desc.GetElementSpace()];
...@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread, p_out_thread,
out_thread_desc.GetLengths()); out_thread_desc.GetLengths());
for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread) c_thread_data_begin += CPerThread)
{ {
// threadwise convolution // threadwise convolution
......
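The loop above walks a flat thread_work_id and peels it into (n, k, ho, wo) work coordinates by repeated divide-and-subtract. A standalone, host-only sketch of the same decomposition, using an illustrative work grid that is not taken from any kernel configuration above:

    // Sketch: unflatten a work id over a (N, K, Y, X) work grid of (1, 8, 2, 16).
    #include <cstdio>

    int main()
    {
        const unsigned NThreadWork = 1, KThreadWork = 8, YThreadWork = 2, XThreadWork = 16;
        const unsigned total = NThreadWork * KThreadWork * YThreadWork * XThreadWork; // 256 work items

        const unsigned thread_work_id = 77; // any id in [0, total)

        unsigned itmp = thread_work_id;
        const unsigned n = itmp / (KThreadWork * YThreadWork * XThreadWork); // 77 / 256 = 0
        itmp -= n * (KThreadWork * YThreadWork * XThreadWork);               // itmp = 77
        const unsigned k = itmp / (YThreadWork * XThreadWork);               // 77 / 32 = 2
        itmp -= k * (YThreadWork * XThreadWork);                             // itmp = 13
        const unsigned y = itmp / XThreadWork;                               // 13 / 16 = 0
        const unsigned x = itmp - y * XThreadWork;                           // x = 13

        printf("work id %u of %u -> (n %u, k %u, ho %u, wo %u)\n", thread_work_id, total, n, k, y, x);
        return 0;
    }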
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "Array.hip.hpp" #include "Array.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2> template <class T1, class T2>
struct is_same struct is_same
...@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b) ...@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
} }
#endif #endif
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b) __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{ {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
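A quick check of the ceiling-division helper (illustrative values), matching the per-block tile counts computed in the blockwise code above:

    static_assert(integer_divide_ceil(34, 4) == 9, "34 / 4 rounds up to 9");
    static_assert(integer_divide_ceil(32, 4) == 8, "exact division is unchanged");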
...@@ -11,3 +11,5 @@ ...@@ -11,3 +11,5 @@
#include "nvToolsExt.h" #include "nvToolsExt.h"
#include "helper_cuda.h" #include "helper_cuda.h"
#endif #endif
using index_t = uint32_t;
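The alias above is the pivot of this commit: every former unsigned index now routes through index_t, so the index width lives in one place. The commit message only says "experimenting", so the motivation is an assumption; a hedged sketch of what the single alias buys, with hypothetical alternative definitions shown as comments:

    // Current definition (from this commit):
    //   using index_t = uint32_t;
    //
    // Hypothetical retargets, each a one-line change because descriptors,
    // strides and loop counters all use the alias:
    //   using index_t = uint64_t; // wider offsets for very large tensors
    //   using index_t = int32_t;  // signed index arithmetic, e.g. for padding math

    static_assert(sizeof(index_t) == 4, "index_t is currently a 32-bit type");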
...@@ -8,5 +8,5 @@ struct integral_constant ...@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; } __host__ __device__ constexpr T Get() const { return value; }
}; };
template <unsigned N> template <index_t N>
using Number = integral_constant<unsigned, N>; using Number = integral_constant<index_t, N>;
#pragma once #pragma once
#include "config.h" #include "config.h"
template <class T, unsigned N> template <class T, index_t N>
struct vector_type struct vector_type
{ {
}; };
......