Commit 766b0a9e authored by Chao Liu

experimenting

parent f35c64eb
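This commit mechanically replaces the built-in unsigned / unsigned long index types with the project-wide index_t alias across the device-side convolution drivers and the host-side test driver (tensor lengths, tuning constants, grid sizes, and repeat-loop counters). A minimal sketch of the idea follows; the actual alias is defined elsewhere in the repository, so the header name and the chosen width/signedness below are assumptions, not part of this diff:

    // hypothetical common header (e.g. a config.hpp) -- assumption, not part of this commit
    #include <cstdint>

    // One project-wide index type for tensor lengths, strides, grid sizes and loop
    // counters, so its width/signedness can be changed in a single place later.
    using index_t = std::uint32_t; // assumed here to be 32-bit unsigned

    // Call sites then spell index_t instead of a raw built-in type:
    template <class F>
    void run_nrepeat_times(index_t nrepeat, F&& launch_once)
    {
        for(index_t i = 0; i < nrepeat; ++i)
        {
            launch_once(i); // e.g. launch one timed kernel run
        }
    }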
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
 #if 1
 // 3x3, 34x34
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 16;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 16;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(gridwise_direct_convolution_1<T,
 InDesc,
...
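GridSize here is simply the number of output tiles: one workgroup per NPerBlock x KPerBlock x HoPerBlock x WoPerBlock tile of the NKHW output. As a worked example, assuming the 3x3, 34x34 problem from the driver further below is the one selected (N = 64, K = 64, zero padding, so Ho = Wo = 32), the block sizes above give (64 / 2) * (64 / 16) * (32 / 4) * (32 / 32) = 32 * 4 * 8 * 1 = 1024 workgroups of BlockSize = 128 threads each.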
@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 #if 1
 // 3x3, 34x34, 128 thread
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3, 34x34, 128 thread, fp16
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time =
 launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
...
@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 const Tensor<TInWei>& wei_kcyx,
 OutDesc,
 Tensor<TOut>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 // this suppose in / wei data type is int8x4
-constexpr unsigned NVector = 4;
+constexpr index_t NVector = 4;
 using accum_t = int32_t;
 using vector_t = vector_type<TInWei, NVector>;
 using vector_mem_t = typename vector_t::MemoryType;
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // vectorized input
 auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 #if 0
 // 3x3, 34x34, 128 thread, fp32, vector = 1
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, fp32, vector = 2
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 1x1, 32x32, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 16;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 16;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
...
@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 #if 0
 // for 3x3, 34x34
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 2;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #elif 1
 // for 1x1, 14x14
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned OutThreadCopyDataPerWrite = 2;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t OutThreadCopyDataPerWrite = 2;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,
...
@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 Tensor<T>& out_nkhw,
 LowerPads,
 UpperPads,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 #if 0
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 1;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 1;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
-constexpr unsigned BlockSize = 8;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 1;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 1;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
+constexpr index_t BlockSize = 8;
 #elif 1
 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3 56x56, NKC = 16,256,128, with padding
 // 3x3 28x28, NKC = 16,512,256, with padding
 // 3x3 20x84, NKC = 16,256,256, with padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5 filter, 20x84 image, 1x1 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 5x5 filter, 28x28 image, 2x2 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-constexpr unsigned BlockSize = 128;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
...
@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};
-constexpr unsigned N = in_nchw_desc.GetLength(I0);
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
-constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
+constexpr index_t N = in_nchw_desc.GetLength(I0);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
 // convert in_nchw to in_cnhw
 auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 #if 0
 // 3x3, 34x34
 // need to use register double buffer for GEMM
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 8;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 8;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 64 threads
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 64;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 64;
-#elif 1
+#elif 0
 // 1x1, 28x28, 128 threads, no lds-double-buffer
 // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 128;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 256 thread
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 4;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 256;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 256;
+#elif 1
+// 1x1, 14x14, Vega 10
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+constexpr index_t BlockSize = 128;
 #endif
-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());
-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 #if 1
...
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
 template <class... Ts>
 double operator()(Ts... Xs) const
 {
-std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
+std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
 return std::accumulate(dims.begin(),
 dims.end(),
 true,
-[](bool init, unsigned long x) -> int { return init != (x % 2); })
+[](bool init, index_t x) -> int { return init != (x % 2); })
 ? 1
 : -1;
 }
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
 constexpr auto I3 = Number<3>{};
 constexpr auto desc = TConstTensorDesc{};
-std::initializer_list<unsigned> lengths = {
+std::initializer_list<index_t> lengths = {
 desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
-std::initializer_list<unsigned> strides = {
+std::initializer_list<index_t> strides = {
 desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
 return TensorDescriptor(lengths, strides);
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
 LowerPads,
 UpperPads)
 {
-unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
-unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+index_t w_pad_low = LowerPads{}.Get(Number<1>{});
+index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+index_t w_pad_up = UpperPads{}.Get(Number<1>{});
 auto f = [&](auto n, auto k, auto ho, auto wo) {
 double v = 0;
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
 std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
 std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
-unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
-unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+index_t w_pad_low = LowerPads{}.Get(Number<1>{});
+index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+index_t w_pad_up = UpperPads{}.Get(Number<1>{});
 std::size_t HiPerTile = HoPerTile + Y - 1;
 std::size_t WiPerTile = WoPerTile + X - 1;
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
 #if 0
-constexpr unsigned N = 1;
-constexpr unsigned C = 1;
-constexpr unsigned HI = 28;
-constexpr unsigned WI = 28;
-constexpr unsigned K = 1;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 1;
+constexpr index_t C = 1;
+constexpr index_t HI = 28;
+constexpr index_t WI = 28;
+constexpr index_t K = 1;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 34x34
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 34;
-constexpr unsigned WI = 34;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 34;
+constexpr index_t WI = 34;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 56x56
-constexpr unsigned N = 64;
-constexpr unsigned C = 64;
-constexpr unsigned HI = 56;
-constexpr unsigned WI = 56;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 64;
+constexpr index_t C = 64;
+constexpr index_t HI = 56;
+constexpr index_t WI = 56;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 3x3, 58x58
-constexpr unsigned N = 64;
-constexpr unsigned C = 64;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 64;
+constexpr index_t C = 64;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 64;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 5x5, 36x36
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 36;
-constexpr unsigned WI = 36;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 5;
-constexpr unsigned X = 5;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 36;
+constexpr index_t WI = 36;
+constexpr index_t K = 64;
+constexpr index_t Y = 5;
+constexpr index_t X = 5;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 7x7, 38x38
-constexpr unsigned N = 64;
-constexpr unsigned C = 256;
-constexpr unsigned HI = 38;
-constexpr unsigned WI = 38;
-constexpr unsigned K = 64;
-constexpr unsigned Y = 7;
-constexpr unsigned X = 7;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 64;
+constexpr index_t C = 256;
+constexpr index_t HI = 38;
+constexpr index_t WI = 38;
+constexpr index_t K = 64;
+constexpr index_t Y = 7;
+constexpr index_t X = 7;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3, 58x58
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
 #elif 0
 // 3x3 filter, 58x58 image, 0x0 padding
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 58;
-constexpr unsigned WI = 58;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
-constexpr unsigned X = 3;
-constexpr unsigned HPad = 0;
-constexpr unsigned WPad = 0;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 58;
+constexpr index_t WI = 58;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
+constexpr index_t X = 3;
+constexpr index_t HPad = 0;
+constexpr index_t WPad = 0;
 #elif 0
 // 3x3 filter, 56x56 image, 1x1 padding
-constexpr unsigned N = 16;
-constexpr unsigned C = 128;
-constexpr unsigned HI = 56;
-constexpr unsigned WI = 56;
-constexpr unsigned K = 256;
-constexpr unsigned Y = 3;
+constexpr index_t N = 16;
+constexpr index_t C = 128;
+constexpr index_t HI = 56;
+constexpr index_t WI = 56;
+constexpr index_t K = 256;
+constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 3x3 filter, 28x28 image, 1x1 padding // 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
// 3x3 filter, 20x84 image, 1x1 padding // 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 20; constexpr index_t HI = 20;
constexpr unsigned WI = 84; constexpr index_t WI = 84;
constexpr unsigned K = 256; constexpr index_t K = 256;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 3x3 filter, 112x112 image, 1x1 padding // 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 64; constexpr index_t C = 64;
constexpr unsigned HI = 112; constexpr index_t HI = 112;
constexpr unsigned WI = 112; constexpr index_t WI = 112;
constexpr unsigned K = 128; constexpr index_t K = 128;
constexpr unsigned Y = 3; constexpr index_t Y = 3;
constexpr unsigned X = 3; constexpr index_t X = 3;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 5x5 filter, 20x86 image, 1x1 padding // 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 20; constexpr index_t HI = 20;
constexpr unsigned WI = 86; constexpr index_t WI = 86;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 5; constexpr index_t Y = 5;
constexpr unsigned X = 5; constexpr index_t X = 5;
constexpr unsigned HPad = 1; constexpr index_t HPad = 1;
constexpr unsigned WPad = 1; constexpr index_t WPad = 1;
#elif 0 #elif 0
// 5x5 filter, 28x28 image, 2x2 padding // 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16; constexpr index_t N = 16;
constexpr unsigned C = 192; constexpr index_t C = 192;
constexpr unsigned HI = 28; constexpr index_t HI = 28;
constexpr unsigned WI = 28; constexpr index_t WI = 28;
constexpr unsigned K = 32; constexpr index_t K = 32;
constexpr unsigned Y = 5; constexpr index_t Y = 5;
constexpr unsigned X = 5; constexpr index_t X = 5;
constexpr unsigned HPad = 2; constexpr index_t HPad = 2;
constexpr unsigned WPad = 2; constexpr index_t WPad = 2;
#elif 0 #elif 0
// 1x1 filter, 32x32 image // 1x1 filter, 32x32 image
constexpr unsigned N = 64; constexpr index_t N = 64;
constexpr unsigned C = 256; constexpr index_t C = 256;
constexpr unsigned HI = 32; constexpr index_t HI = 32;
constexpr unsigned WI = 32; constexpr index_t WI = 32;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image, C = 2048
constexpr unsigned N = 128; constexpr index_t N = 128;
constexpr unsigned C = 2048; constexpr index_t C = 2048;
constexpr unsigned HI = 14; constexpr index_t HI = 14;
constexpr unsigned WI = 14; constexpr index_t WI = 14;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 1
// 1x1 filter, 14x14 image, C = 512 // 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128; constexpr index_t N = 128;
constexpr unsigned C = 512; constexpr index_t C = 512;
constexpr unsigned HI = 14; constexpr index_t HI = 14;
constexpr unsigned WI = 14; constexpr index_t WI = 14;
constexpr unsigned K = 512; constexpr index_t K = 512;
constexpr unsigned Y = 1; constexpr index_t Y = 1;
constexpr unsigned X = 1; constexpr index_t X = 1;
constexpr unsigned HPad = 0; constexpr index_t HPad = 0;
constexpr unsigned WPad = 0; constexpr index_t WPad = 0;
#endif #endif
auto lower_pads = Sequence<HPad, WPad>{}; auto lower_pads = Sequence<HPad, WPad>{};
...@@ -634,7 +634,7 @@ int main(int argc, char* argv[]) ...@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
} }
bool do_verification = atoi(argv[1]); bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]); index_t nrepeat = atoi(argv[2]);
if(do_verification) if(do_verification)
{ {
......
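The driver reads do_verification from argv[1] and nrepeat from argv[2]; the #if/#elif chain above just selects one problem size at compile time. For the selected shapes, the output spatial size follows the stride-1, dilation-1 relation HO = HI + 2*HPad - Y + 1 (stride and dilation are not parameters in these configs, so unit values are assumed). A standalone check against a few of the listed configurations:

using index_t = unsigned; // assumption

constexpr index_t out_size(index_t in, index_t pad, index_t filter)
{
    return in + 2 * pad - filter + 1;
}

static_assert(out_size(14, 0, 1) == 14, "1x1 filter, no pad: 14x14 stays 14x14");
static_assert(out_size(56, 1, 3) == 56, "3x3 filter, pad 1: 56x56 stays 56x56");
static_assert(out_size(34, 0, 3) == 32, "3x3 filter, no pad: 34x34 becomes 32x32");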
#pragma once #pragma once
template <class TData, unsigned NSize> template <class TData, index_t NSize>
struct Array struct Array
{ {
using Type = Array<TData, NSize>; using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize; static constexpr index_t nSize = NSize;
unsigned mData[nSize]; index_t mData[nSize];
template <class... Xs> template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...} __host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{ {
} }
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; } __host__ __device__ TData operator[](index_t i) const { return mData[i]; }
}; };
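A simplified, host-only analogue of the Array above (no __host__/__device__ qualifiers; names are illustrative), showing the intended use: packing a variadic multi-index into a fixed-size array and reading it back by position.

#include <cstdio>

using index_t = unsigned; // assumption

template <class TData, index_t NSize>
struct SimpleArray
{
    TData mData[NSize];

    template <class... Xs>
    constexpr SimpleArray(Xs... xs) : mData{static_cast<TData>(xs)...} {}

    constexpr TData operator[](index_t i) const { return mData[i]; }
};

int main()
{
    // a 4-d multi-index (n, c, h, w)
    constexpr SimpleArray<index_t, 4> multi_id(0, 3, 5, 7);
    std::printf("%u %u %u %u\n", multi_id[0], multi_id[1], multi_id[2], multi_id[3]);
}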
#pragma once #pragma once
#include "common.hip.hpp" #include "common.hip.hpp"
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_> template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor struct ConstantMatrixDescriptor
{ {
__host__ __device__ constexpr ConstantMatrixDescriptor() __host__ __device__ constexpr ConstantMatrixDescriptor()
...@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor ...@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!"); static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
} }
__host__ __device__ constexpr unsigned NRow() const { return NRow_; } __host__ __device__ constexpr index_t NRow() const { return NRow_; }
__host__ __device__ constexpr unsigned NCol() const { return NCol_; } __host__ __device__ constexpr index_t NCol() const { return NCol_; }
__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; } __host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; } __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; } __host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; } __host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const __host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{ {
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol; return irow * RowStride_ + icol;
#endif
} }
template <unsigned SubNRow, unsigned SubNCol> template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const Number<SubNCol>) const
{ {
...@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor ...@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
} }
}; };
template <unsigned NRow, unsigned NCol> template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>) __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{ {
return ConstantMatrixDescriptor<NRow, NCol, NCol>{}; return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
} }
template <unsigned NRow, unsigned NCol, unsigned RowStride> template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>) make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{ {
......
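On the HIP backend this commit switches Get1dIndex to __mul24, a 24-bit integer multiply intrinsic available on CUDA/HIP-style targets; it only matches the plain multiply while irow and RowStride_ fit in 24 bits, which is the implicit assumption here. Both branches compute the same row-major offset, sketched standalone below (the 16x32 matrix with a padded row stride of 36 is an illustrative example, not taken from the driver):

using index_t = unsigned; // assumption

constexpr index_t get_1d_index(index_t irow, index_t icol, index_t row_stride)
{
    return irow * row_stride + icol;
}

// a 16x32 matrix stored with a padded row stride of 36
static_assert(get_1d_index(1, 0, 36) == 36, "the next row starts one RowStride later");
static_assert(get_1d_index(2, 5, 36) == 77, "2*36 + 5");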
...@@ -2,35 +2,35 @@ ...@@ -2,35 +2,35 @@
#include "common.hip.hpp" #include "common.hip.hpp"
// this is ugly, only for 2d // this is ugly, only for 2d
template <unsigned L0, unsigned L1> template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{ {
return Sequence<L1, 1>{}; return Sequence<L1, 1>{};
} }
// this is ugly, only for 4d // this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3> template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{ {
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{}; return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
} }
// this is ugly, only for 6d // this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5> template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{ {
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{}; return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
} }
// this is ugly, only for 8d // this is ugly, only for 8d
template <unsigned L0, template <index_t L0,
unsigned L1, index_t L1,
unsigned L2, index_t L2,
unsigned L3, index_t L3,
unsigned L4, index_t L4,
unsigned L5, index_t L5,
unsigned L6, index_t L6,
unsigned L7> index_t L7>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>) calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{ {
...@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto ...@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
} }
// this is ugly, only for 2d // this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align> template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>, __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>) Number<Align>)
{ {
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align); constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{}; return Sequence<L1_align, 1>{};
} }
// this is ugly, only for 4d // this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align> template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>, __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>) Number<Align>)
{ {
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align); constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{}; return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
} }
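calculate_default_strides_aligned rounds the innermost length up to a multiple of Align and then builds packed strides from the padded length. A standalone check of that arithmetic for one illustrative shape (lengths (2, 4, 4, 34) with Align = 4; these numbers are assumptions, not taken from the driver):

using index_t = unsigned; // assumption

constexpr index_t round_up(index_t x, index_t align)
{
    return align * ((x + align - 1) / align);
}

static_assert(round_up(34, 4) == 36, "L3 = 34 rounds up to 36");
// resulting strides: {L1 * L2 * 36, L2 * 36, 36, 1} = {576, 144, 36, 1}
static_assert(4 * 4 * round_up(34, 4) == 576, "outermost stride");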
template <class Lengths, class Strides> template <class Lengths, class Strides>
struct ConstantTensorDescriptor struct ConstantTensorDescriptor
{ {
using Type = ConstantTensorDescriptor<Lengths, Strides>; using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim; static constexpr index_t nDim = Lengths::nDim;
__host__ __device__ constexpr ConstantTensorDescriptor() __host__ __device__ constexpr ConstantTensorDescriptor()
{ {
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent"); static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
} }
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; } __host__ __device__ constexpr index_t GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; } __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; } __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const __host__ __device__ constexpr index_t GetLength(Number<I>) const
{ {
return Lengths{}.Get(Number<I>{}); return Lengths{}.Get(Number<I>{});
} }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const __host__ __device__ constexpr index_t GetStride(Number<I>) const
{ {
return Strides{}.Get(Number<I>{}); return Strides{}.Get(Number<I>{});
} }
...@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor ...@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f struct GetElementSize_f
{ {
template <class IDim> template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const __host__ __device__ constexpr index_t operator()(IDim idim) const
{ {
return Type{}.GetLength(idim); return Type{}.GetLength(idim);
} }
}; };
__host__ __device__ constexpr unsigned GetElementSize() const __host__ __device__ constexpr index_t GetElementSize() const
{ {
// c++14 doesn't support constexpr lambdas, so this trick is used instead // c++14 doesn't support constexpr lambdas, so this trick is used instead
struct multiply struct multiply
{ {
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{ {
return a * b; return a * b;
} }
...@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor ...@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f struct GetElementSpace_f
{ {
template <class IDim> template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const __host__ __device__ constexpr index_t operator()(IDim idim) const
{ {
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim); return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
} }
}; };
template <class Align = Number<1>> template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const __host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{ {
// c++14 doesn't support constexpr lambdas, so this trick is used instead // c++14 doesn't support constexpr lambdas, so this trick is used instead
struct add struct add
{ {
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{ {
return a + b; return a + b;
} }
...@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor ...@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
} }
template <class... Is> template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const __host__ __device__ index_t Get1dIndex(Is... is) const
{ {
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<unsigned, nDim>(is...); const auto multi_id = Array<index_t, nDim>(is...);
unsigned id = 0; index_t id = 0;
static_loop_n<nDim>{}([&](auto IDim) { static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get(); constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim); id += multi_id[idim] * GetStride(IDim);
#endif
}); });
return id; return id;
...@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor ...@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{}; return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
} }
template <unsigned IDim, unsigned NVector> template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const __host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{ {
assert(false); // not implemented assert(false); // not implemented
...@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride ...@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{}; return ConstantTensorDescriptor<Lengths, Strides>{};
} }
template <class Lengths, unsigned Align> template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>) __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{ {
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{})); using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
...@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths ...@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
template <class TDesc> template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{ {
constexpr auto desc = TDesc{}; constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension(); constexpr index_t ndim = desc.GetDimension();
static_assert(ndim >= 2 && ndim <= 8, "wrong!"); static_assert(ndim >= 2 && ndim <= 8, "wrong!");
......
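GetElementSize above is the product of the lengths (the logical element count), while GetElementSpace accumulates (length - 1) * stride per dimension, i.e. the largest reachable offset; with padded strides the buffer therefore has to be larger than the element count. The base value the accumulation starts from is not visible in this hunk, so the standalone sketch below only compares the two quantities for the padded layout derived in the previous example:

using index_t = unsigned; // assumption

// lengths (2, 4, 4, 34) with padded strides {576, 144, 36, 1}
constexpr index_t lengths[4] = {2, 4, 4, 34};
constexpr index_t strides[4] = {576, 144, 36, 1};

constexpr index_t element_size = lengths[0] * lengths[1] * lengths[2] * lengths[3];
constexpr index_t last_offset  = (lengths[0] - 1) * strides[0] + (lengths[1] - 1) * strides[1] +
                                 (lengths[2] - 1) * strides[2] + (lengths[3] - 1) * strides[3];

static_assert(element_size == 1088, "2*4*4*34 logical elements");
static_assert(last_offset == 1149, "a buffer must cover offsets 0..1149, i.e. 1150 elements");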
...@@ -2,38 +2,38 @@ ...@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp" #include "constant_integral.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
template <unsigned... Is> template <index_t... Is>
struct Sequence struct Sequence
{ {
using Type = Sequence<Is...>; using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is); static constexpr index_t nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...}; const index_t mData[nDim] = {Is...};
template <unsigned I> template <index_t I>
__host__ __device__ constexpr unsigned Get(Number<I>) const __host__ __device__ constexpr index_t Get(Number<I>) const
{ {
return mData[I]; return mData[I];
} }
// this is ugly, only for nDim = 4 // this is ugly, only for nDim = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3> template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{ {
static_assert(nDim == 4, "nDim != 4"); static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{}; constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0]; constexpr index_t NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1]; constexpr index_t NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2]; constexpr index_t NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3]; constexpr index_t NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{}; return Sequence<NR0, NR1, NR2, NR3>{};
} }
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3> template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{ {
// don't know how to implement this // don't know how to implement this
...@@ -41,7 +41,7 @@ struct Sequence ...@@ -41,7 +41,7 @@ struct Sequence
assert(false); assert(false);
} }
template <unsigned I> template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const __host__ __device__ constexpr auto PushBack(Number<I>) const
{ {
return Sequence<Is..., I>{}; return Sequence<Is..., I>{};
...@@ -56,14 +56,14 @@ struct Sequence ...@@ -56,14 +56,14 @@ struct Sequence
} }
}; };
template <unsigned... Is, unsigned I> template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>) __host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{ {
static_assert(sizeof...(Is) >= 1, "empty Sequence!"); static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{}; return Sequence<Is...>{};
} }
template <class F, unsigned... Xs, unsigned... Ys> template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f) __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{ {
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same"); static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
...@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc ...@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{}; return Sequence<f(Xs, Ys)...>{};
} }
template <unsigned... Xs, unsigned... Ys> template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{ {
struct add struct add
{ {
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{ {
return x + y; return x + y;
} }
...@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen ...@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{}); return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
} }
template <unsigned... Is> template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{ {
return sequence_pop_back(Type{}); return sequence_pop_back(Type{});
......
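ReorderByGetNewFromOld builds the new sequence by reading position map[i] of the old one for each new position i. A standalone illustration with assumed values (NCHW lengths reordered to NHWC):

using index_t = unsigned; // assumption

constexpr index_t old_seq[4] = {2, 3, 4, 5}; // e.g. NCHW lengths
constexpr index_t map[4]     = {0, 2, 3, 1}; // new-from-old positions

constexpr index_t new0 = old_seq[map[0]]; // 2
constexpr index_t new1 = old_seq[map[1]]; // 4
constexpr index_t new2 = old_seq[map[2]]; // 5
constexpr index_t new3 = old_seq[map[3]]; // 3

static_assert(new0 == 2 && new1 == 4 && new2 == 5 && new3 == 3,
              "NCHW lengths (2,3,4,5) reordered to NHWC (2,4,5,3)");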
#pragma once #pragma once
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F> template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void __device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f) blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{ {
...@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
} }
#endif #endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize; constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0); is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1); const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]); f(p_dst[dindex]);
} }
...@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize()) if(is < desc.GetElementSize())
{ {
const unsigned did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0); is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1); const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]); f(p_dst[dindex]);
} }
...@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3] // Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type, // TODO: in order to optimize mem access for different mem type,
// need to write specialized version // need to write specialized version
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
...@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[2]; index_t did[2];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1); did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]); const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
} }
...@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
unsigned did[2]; index_t did[2];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1); did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]); const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
} }
} }
} }
template <unsigned BlockSize, class Float, class DstDesc> template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst) __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{ {
auto f_set_zero = [](Float& v) { v = Float(0); }; auto f_set_zero = [](Float& v) { v = Float(0); };
...@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst) ...@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero); blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
} }
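The blockwise pointwise ops above turn a flat work-item id into 2-d coordinates by dividing by the reference descriptor's strides and subtracting as they go. A standalone, host-only sketch of that decomposition (the 8x36 tile is an illustrative shape, not taken from the kernels):

using index_t = unsigned; // assumption

struct Coord2d { index_t d0, d1; };

constexpr Coord2d decompose(index_t is, index_t stride0, index_t stride1)
{
    return {is / stride0, (is - (is / stride0) * stride0) / stride1};
}

// an 8x36 tile with packed strides {36, 1}: flat id 75 maps to (2, 3)
static_assert(decompose(75, 36, 1).d0 == 2, "75 / 36");
static_assert(decompose(75, 36, 1).d1 == 3, "75 - 2*36");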
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
...@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, ...@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
} }
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths> template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
struct Blockwise2dTensorCopy1 struct Blockwise2dTensorCopy1
{ {
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
...@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1 ...@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1
// need to be aligned to float4 and float2 // need to be aligned to float4 and float2
// stride1 needs to be 1 for both source and destination // stride1 needs to be 1 for both source and destination
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
class SrcOpLengths, class SrcOpLengths,
unsigned ThreadPerDim0, index_t ThreadPerDim0,
unsigned ThreadPerDim1> index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2 struct Blockwise2dTensorCopy2
{ {
unsigned mThreadId0; index_t mThreadId0;
unsigned mThreadId1; index_t mThreadId1;
__device__ Blockwise2dTensorCopy2() __device__ Blockwise2dTensorCopy2()
{ {
...@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2 ...@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
constexpr bool align_v2 = constexpr bool align_v2 =
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0; src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
constexpr unsigned L0 = SrcOpLengths{}.Get(I0); constexpr index_t L0 = SrcOpLengths{}.Get(I0);
constexpr unsigned L1 = SrcOpLengths{}.Get(I1); constexpr index_t L1 = SrcOpLengths{}.Get(I1);
constexpr unsigned Dim0Loop = L0 / ThreadPerDim0; constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop); constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0; constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr unsigned Dim1V2Loop = constexpr index_t Dim1V2Loop =
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0; align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
constexpr unsigned Dim1V1Loop = constexpr index_t Dim1V1Loop =
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) / (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
ThreadPerDim1; ThreadPerDim1;
constexpr bool d1_has_tail = constexpr bool d1_has_tail =
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop)); (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop) for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
{ {
unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0; index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
// v4 // v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop) for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{ {
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1; index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) = *(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex)); *(reinterpret_cast<const Float4*>(p_src + sindex));
} }
// v2 // v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop) for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{ {
unsigned did1 = index_t did1 =
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1; Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) = *(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex)); *(reinterpret_cast<const Float2*>(p_src + sindex));
} }
// v1 // v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop) for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{ {
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 + index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1; d1v1loop * ThreadPerDim1 + mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex]; p_dst[dindex] = p_src[sindex];
} }
...@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2 ...@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
// dim-1 tail // dim-1 tail
if(d1_has_tail) if(d1_has_tail)
{ {
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 + index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1; Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1) if(did1 < L1)
{ {
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex]; p_dst[dindex] = p_src[sindex];
} }
...@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2 ...@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
// dim-0 tail // dim-0 tail
if(d0_has_tail) if(d0_has_tail)
{ {
unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0; index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
if(did0 < L0) if(did0 < L0)
{ {
// v4 // v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop) for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{ {
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1; index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) = *(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex)); *(reinterpret_cast<const Float4*>(p_src + sindex));
} }
// v2 // v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop) for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{ {
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1; 2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) = *(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex)); *(reinterpret_cast<const Float2*>(p_src + sindex));
} }
// v1 // v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop) for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{ {
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 + d1v1loop * ThreadPerDim1 + mThreadId1;
mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex]; p_dst[dindex] = p_src[sindex];
} }
...@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2 ...@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
// tail // tail
if(d1_has_tail) if(d1_has_tail)
{ {
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 + mThreadId1;
mThreadId1;
if(did1 < L1) if(did1 < L1)
{ {
const unsigned sindex = src_desc.Get1dIndex(did0, did1); const index_t sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1); const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex]; p_dst[dindex] = p_src[sindex];
} }
...@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2 ...@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2
// starting point needs to be aligned to float4 or float2 or float // starting point needs to be aligned to float4 or float2 or float
// stride1 needs to be 1 for both source and destination // stride1 needs to be 1 for both source and destination
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
class CopyLengths, class CopyLengths,
unsigned DataPerRead> index_t DataPerRead>
struct Blockwise2dTensorCopy3 struct Blockwise2dTensorCopy3
{ {
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType; using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset; index_t mSrcMyThreadOffset;
unsigned mDstMyThreadOffset; index_t mDstMyThreadOffset;
__device__ Blockwise2dTensorCopy3() __device__ Blockwise2dTensorCopy3()
{ {
...@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3 ...@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
DstDesc{}.GetStride(I0) % DataPerRead == 0, DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride should be multiple of DataPerRead to keep alignment"); "src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
// we allow out-of-bound read from src in D1 dimension, // we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride is big enough, // but we need to make sure dst stride is big enough,
...@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3 ...@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n"); static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
{ {
...@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3 ...@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
} }
} }
const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1; const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1; const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead); mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead); mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
...@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3 ...@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
{ {
...@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3 ...@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
} }
} }
constexpr unsigned nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) { auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) = *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset + *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride)); iloop * src_loop_stride));
}; };
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop) for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{ {
f_copy(iloop); f_copy(iloop);
} }
...@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3 ...@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0) if(has_tail_d0)
{ {
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{ {
...@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3 ...@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
} }
} }
__device__ constexpr unsigned GetRegisterClipboardSize() const __device__ constexpr index_t GetRegisterClipboardSize() const
{ {
static_assert(is_same<Float, float>::value, "wrong! only support float!\n"); static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0; return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
} }
...@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3 ...@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
{ {
...@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3 ...@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
} }
} }
constexpr unsigned nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) { auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) = *(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset + *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride)); iloop * src_loop_stride));
}; };
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop) for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{ {
f_copy(iloop); f_copy(iloop);
} }
...@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3 ...@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0) if(has_tail_d0)
{ {
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{ {
...@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3 ...@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
{ {
...@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3 ...@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
} }
} }
constexpr unsigned nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0; constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) { auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) = *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4)); *(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
}; };
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop) for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{ {
f_copy(iloop); f_copy(iloop);
} }
...@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3 ...@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0) if(has_tail_d0)
{ {
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{ {
......
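Blockwise2dTensorCopy3 derives its thread mapping from the copy lengths and DataPerRead: enough threads to cover one row in vector_t-sized reads, the rest of the block stacked along d0, and a d0 loop (plus tail) over the remaining rows. A standalone check for one plausible configuration (BlockSize = 128, copy lengths 8 x 128, DataPerRead = 4; these numbers are assumptions, not taken from the driver):

using index_t = unsigned; // assumption

constexpr index_t BlockSize   = 128;
constexpr index_t L0          = 8;
constexpr index_t L1          = 128;
constexpr index_t DataPerRead = 4;

constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; // threads needed per row
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;            // rows covered per pass
constexpr index_t nloop_d0      = L0 / thread_per_d0;                   // full passes over d0

static_assert(thread_per_d1 == 32, "32 threads, each reading 4 floats, cover a 128-wide row");
static_assert(thread_per_d0 == 4, "so one 128-thread pass covers 4 rows");
static_assert(nloop_d0 == 2 && L0 - nloop_d0 * thread_per_d0 == 0, "two passes, no d0 tail");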
#pragma once #pragma once
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F> template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void __device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f) blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{ {
...@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
} }
#endif #endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize; constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0); is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1); const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1); is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2); const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2); is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3); const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]); f(p_dst[dindex]);
} }
...@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize()) if(is < desc.GetElementSize())
{ {
const unsigned did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0); is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1); const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1); is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2); const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2); is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3); const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]); f(p_dst[dindex]);
} }
...@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3] // Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type, // TODO: in order to optimize mem access for different mem type,
// need to write specialized version // need to write specialized version
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
...@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2); constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3); constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4]; index_t did[4];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3); did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]); const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]); f(p_src[src_index], p_dst[dst_index]);
} }
...@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
unsigned did[4]; index_t did[4];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3); did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]); const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]); f(p_src[src_index], p_dst[dst_index]);
} }
} }
} }
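The reorder kernel above writes each source element to the destination index obtained by permuting its 4-d coordinates with DstFromSrcReorder: a flat work index is decomposed with the reference descriptor's strides and re-composed with the destination strides. A host-side sketch of the same mapping; the lengths, the reorder, and the packed-stride helper are assumptions for illustration only:

#include <array>
#include <cassert>
#include <vector>

// Illustrative only: apply a dst-from-src dimension reorder to a tiny 4-d tensor.
int main()
{
    const std::array<int, 4> len = {2, 3, 4, 5}; // hypothetical SrcOpLengths
    const std::array<int, 4> IR  = {1, 2, 3, 0}; // hypothetical DstFromSrcReorder

    std::array<int, 4> dst_len{};
    for(int k = 0; k < 4; ++k)
        dst_len[k] = len[IR[k]]; // dst dimension k takes src coordinate did[IR[k]]

    auto packed_strides = [](const std::array<int, 4>& l) {
        std::array<int, 4> s{};
        s[3] = 1;
        for(int d = 2; d >= 0; --d)
            s[d] = s[d + 1] * l[d + 1];
        return s;
    };
    const auto src_stride = packed_strides(len);
    const auto dst_stride = packed_strides(dst_len);

    const int n = len[0] * len[1] * len[2] * len[3];
    std::vector<float> src(n), dst(n, 0.f);
    for(int i = 0; i < n; ++i)
        src[i] = float(i);

    for(int is = 0; is < n; ++is)
    {
        int tmp = is;
        int did[4];
        for(int d = 0; d < 4; ++d)
        {
            did[d] = tmp / src_stride[d];
            tmp -= did[d] * src_stride[d];
        }
        const int dst_index = did[IR[0]] * dst_stride[0] + did[IR[1]] * dst_stride[1] +
                              did[IR[2]] * dst_stride[2] + did[IR[3]] * dst_stride[3];
        dst[dst_index] = src[is];
    }

    assert(dst[1] == src[src_stride[0]]); // src element (1,0,0,0) lands at dst (0,0,0,1)
    return 0;
}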
template <unsigned BlockSize, class Float, class DstDesc> template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst) __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{ {
auto f_set_zero = [](Float& v) { v = Float(0); }; auto f_set_zero = [](Float& v) { v = Float(0); };
...@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst) ...@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero); blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
} }
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
...@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, ...@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
} }
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
class CopyLengths, class CopyLengths,
unsigned DataPerRead> index_t DataPerRead>
struct Blockwise4dTensorCopy1 struct Blockwise4dTensorCopy1
{ {
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType; using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
...@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1 ...@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
// we allow out-of-bound read from src in D3 dimension, // we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough, // but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate the next line in dst // so that the out-of-bound write won't contaminate the next line in dst
constexpr unsigned L3 = CopyLengths{}.Get(I3); constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead); constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2), static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n"); "wrong! out-of-bound write will contaminate next line!\n");
...@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1 ...@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2); constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3); constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead); constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr auto ref_desc = constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{}); make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](unsigned is) { auto f_copy = [&](index_t is) {
unsigned did[4]; index_t did[4];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1 ...@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1
did[3] = is / ref_desc.GetStride(I3); did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = const index_t src_index =
src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead); src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
const unsigned dst_index = const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead); dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) = *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index)); *(reinterpret_cast<const vector_t*>(p_src + src_index));
}; };
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
f_copy(is); f_copy(is);
} }
...@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1 ...@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
...@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1 ...@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
} }
}; };
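Blockwise4dTensorCopy1 above deliberately allows an out-of-bound vector read along D3 and leans on the destination stride to absorb the overshoot; the static_assert encodes read_per_d3 * DataPerRead <= dst stride2. A standalone sketch of that arithmetic with made-up lengths (not values from this commit):

#include <cstdio>

// Illustrative only: ceil-divide for vectorized reads along D3 and the stride
// condition that keeps the overshoot from spilling into the next line of dst.
constexpr int integer_divide_ceil_example(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int L3          = 34; // hypothetical copy length along D3
    constexpr int DataPerRead = 4;  // float4-style vector read
    constexpr int dst_stride2 = 36; // hypothetical dst stride between lines

    constexpr int read_per_d3 = integer_divide_ceil_example(L3, DataPerRead); // 9 vector reads
    static_assert(read_per_d3 * DataPerRead <= dst_stride2,
                  "out-of-bound write would contaminate the next line");

    std::printf("vector reads per line %d, elements touched %d, dst stride2 %d\n",
                read_per_d3, read_per_d3 * DataPerRead, dst_stride2);
    return 0;
}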
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
...@@ -315,15 +315,15 @@ template <unsigned BlockSize, ...@@ -315,15 +315,15 @@ template <unsigned BlockSize,
struct BlockwiseChwnTensorCopyPadded struct BlockwiseChwnTensorCopyPadded
{ {
__device__ void Run(const Float* __restrict__ p_src, __device__ void Run(const Float* __restrict__ p_src,
unsigned c_block_data_begin, index_t c_block_data_begin,
unsigned ho_block_data_begin, index_t ho_block_data_begin,
unsigned wo_block_data_begin, index_t wo_block_data_begin,
unsigned n_block_data_begin, index_t n_block_data_begin,
Float* __restrict__ p_dst, Float* __restrict__ p_dst,
unsigned h_block_pad_low, index_t h_block_pad_low,
unsigned w_block_pad_low, index_t w_block_pad_low,
unsigned h_block_pad_up, index_t h_block_pad_up,
unsigned w_block_pad_up) const index_t w_block_pad_up) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0); constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1); constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp = const Float* p_src_tmp =
p_src + p_src +
...@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
} }
#endif #endif
for(unsigned iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
unsigned is = threadIdx.x + iloop * BlockSize; index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4]; index_t did[4];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3); did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]); const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] = p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) || (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
...@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded
if(has_tail) if(has_tail)
{ {
unsigned is = threadIdx.x + NLoop * BlockSize; index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
unsigned did[4]; index_t did[4];
did[0] = is / ref_desc.GetStride(I0); did[0] = is / ref_desc.GetStride(I0);
...@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3); did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]); const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] = p_dst[bindex] =
(did[1] < h_block_pad_low || (did[1] < h_block_pad_low ||
...@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded
// starting point needs to be aligned to float4 or float2 or float // starting point needs to be aligned to float4 or float2 or float
// stride3 needs to be 1 for both source and destination // stride3 needs to be 1 for both source and destination
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
class DstDesc, class DstDesc,
class CopyLengths, class CopyLengths,
class ThreadPerDims, class ThreadPerDims,
unsigned DataPerRead> index_t DataPerRead>
struct Blockwise4dTensorCopy3 struct Blockwise4dTensorCopy3
{ {
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType; using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset; index_t mSrcMyThreadOffset;
unsigned mDstMyThreadOffset; index_t mDstMyThreadOffset;
__device__ Blockwise4dTensorCopy3() __device__ Blockwise4dTensorCopy3()
{ {
...@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3 ...@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
DstDesc{}.GetStride(I2) % DataPerRead == 0, DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment"); "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2); constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3); constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0); constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1); constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2); constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3); constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
// we allow out-of-bound read from src in D3 dimension, // we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough, // but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate the next line in dst // so that the out-of-bound write won't contaminate the next line in dst
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2), static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n"); "wrong! out-of-bound write will contaminate next line!\n");
...@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3 ...@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3, static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrrong! BlockSize is not big enough for ThreadPerDims!"); "wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr unsigned num_active_thread = constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3; thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
...@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3 ...@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
} }
} }
const unsigned thread_id_d0 = const index_t thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3); get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
unsigned itmp = get_thread_local_1d_id() - index_t itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3); thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3); const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3); itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const unsigned thread_id_d2 = itmp / thread_per_d3; const index_t thread_id_d2 = itmp / thread_per_d3;
const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3; const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex( mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead); thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
...@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3 ...@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0); constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1); constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2); constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3); constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0); constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1); constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2); constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3); constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr unsigned num_active_thread = constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3; thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread) if(BlockSize > num_active_thread)
...@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3 ...@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
} }
} }
constexpr unsigned nloop_d0 = L0 / thread_per_d0; constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned nloop_d1 = L1 / thread_per_d1; constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr unsigned nloop_d2 = L2 / thread_per_d2; constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll #pragma unroll
for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{ {
#pragma unroll #pragma unroll
for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1) for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{ {
#pragma unroll #pragma unroll
for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2) for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{ {
#pragma unroll #pragma unroll
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{ {
const unsigned src_offset = const index_t src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1, iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2, iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead); iloop_d3 * thread_per_d3 * DataPerRead);
const unsigned dst_offset = const index_t dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1, iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2, iloop_d2 * thread_per_d2,
......
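Blockwise4dTensorCopy3 above derives each thread's 4-d coordinate from its flat id by repeated divide-and-subtract against ThreadPerDims. A host-side sketch of that decomposition with assumed per-dimension thread counts:

#include <cassert>

// Illustrative only: flat thread id -> (d0, d1, d2, d3), then recompose to verify.
int main()
{
    constexpr int thread_per_d0 = 2, thread_per_d1 = 4, thread_per_d2 = 4, thread_per_d3 = 4;

    for(int tid = 0; tid < thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3; ++tid)
    {
        const int d0 = tid / (thread_per_d1 * thread_per_d2 * thread_per_d3);
        int itmp     = tid - d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
        const int d1 = itmp / (thread_per_d2 * thread_per_d3);
        itmp -= d1 * (thread_per_d2 * thread_per_d3);
        const int d2 = itmp / thread_per_d3;
        const int d3 = itmp - d2 * thread_per_d3;

        // Recomposition must give back the flat id.
        assert(((d0 * thread_per_d1 + d1) * thread_per_d2 + d2) * thread_per_d3 + d3 == tid);
    }
    return 0;
}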
#pragma once #pragma once
#include "threadwise_gemm.hip.hpp" #include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize, template <index_t BlockSize,
class BlockMatrixA, class BlockMatrixA,
class BlockMatrixB, class BlockMatrixB,
class ThreadMatrixC, class ThreadMatrixC,
bool TransA, bool TransA,
bool TransB, bool TransB,
bool TransC, bool TransC,
unsigned BlockMatrixStrideA, index_t BlockMatrixStrideA,
unsigned BlockMatrixStrideB, index_t BlockMatrixStrideB,
unsigned ThreadMatrixStrideC, index_t ThreadMatrixStrideC,
unsigned BatchSize, index_t BatchSize,
unsigned BatchPerThread, index_t BatchPerThread,
unsigned KPerThreadLoop, index_t KPerThreadLoop,
bool DistributeThreadAlongColumnFirst> bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{ {
unsigned mMyThreadOffsetA = 0; index_t mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0; index_t mMyThreadOffsetB = 0;
struct MatrixIndex struct MatrixIndex
{ {
unsigned batch; index_t batch;
unsigned row; index_t row;
unsigned col; index_t col;
}; };
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC() __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
...@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
#endif #endif
} }
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{ {
if(TransA && (!TransB) && (!TransC)) if(TransA && (!TransB) && (!TransC))
...@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!"); "wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol(); constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol(); constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work // divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0"); static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0"); static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0"); static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread; constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread; constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork, static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize"); "wrong! wrong BlockSize");
...@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst) if(DistributeThreadAlongColumnFirst)
{ {
// num of operations can be reduced // num of operations can be reduced
const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork); const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork); index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const unsigned m_work_id = itmp / NThreadWork; const index_t m_work_id = itmp / NThreadWork;
const unsigned n_work_id = itmp - m_work_id * NThreadWork; const index_t n_work_id = itmp - m_work_id * NThreadWork;
return MatrixIndex{ return MatrixIndex{
b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread}; b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
...@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
// this should be optimized away if input is known // this should be optimized away if input is known
__device__ static MatrixIndex __device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c) GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{ {
return MatrixIndex{batch_in_c, m_in_c, n_in_c}; return MatrixIndex{batch_in_c, m_in_c, n_in_c};
} }
...@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not // a is transposed, b is not
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
...@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{ {
// read first batch of a, b // read first batch of a, b
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy(a_block_mtx,
...@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
b_thread_mtx.GetLengths()); b_thread_mtx.GetLengths());
// loop over batch // loop over batch
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{ {
// do current batch of gemm // do current batch of gemm
threadwise_gemm(a_thread_mtx, threadwise_gemm(a_thread_mtx,
...@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC ...@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
} }
}; };
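GetBeginOfThreadMatrixC in the strided batched GEMM above carves the thread block into BatchThreadWork x MThreadWork x NThreadWork tiles of C. A standalone sketch of that partition; the tile sizes are assumptions for illustration:

#include <cassert>

// Illustrative only: map a flat thread id to the (batch, row, col) origin of its C tile.
struct TileOrigin { int batch, row, col; };

int main()
{
    constexpr int BatchPerThread = 2, MPerThread = 4, NPerThread = 4;
    constexpr int BatchSize = 4, MPerBlock = 32, NPerBlock = 16;

    constexpr int BatchThreadWork = BatchSize / BatchPerThread; // 2
    constexpr int MThreadWork     = MPerBlock / MPerThread;     // 8
    constexpr int NThreadWork     = NPerBlock / NPerThread;     // 4
    constexpr int BlockSize       = BatchThreadWork * MThreadWork * NThreadWork; // 64

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int b_work_id = tid / (MThreadWork * NThreadWork);
        int itmp            = tid - b_work_id * (MThreadWork * NThreadWork);
        const int m_work_id = itmp / NThreadWork;
        const int n_work_id = itmp - m_work_id * NThreadWork;

        const TileOrigin origin{
            b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
        assert(origin.batch < BatchSize && origin.row < MPerBlock && origin.col < NPerBlock);
    }
    return 0;
}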
template <unsigned BlockSize, template <index_t BlockSize,
class BlockMatrixA, class BlockMatrixA,
class BlockMatrixB, class BlockMatrixB,
class ThreadMatrixC, class ThreadMatrixC,
unsigned BlockMatrixStrideA, index_t BlockMatrixStrideA,
unsigned BlockMatrixStrideB, index_t BlockMatrixStrideB,
unsigned ThreadMatrixStrideC, index_t ThreadMatrixStrideC,
unsigned BatchSize, index_t BatchSize,
unsigned MPerThreadSubC, index_t MPerThreadSubC,
unsigned NPerThreadSubC, index_t NPerThreadSubC,
unsigned MLevel0Cluster, index_t MLevel0Cluster,
unsigned NLevel0Cluster, index_t NLevel0Cluster,
unsigned MLevel1Cluster, index_t MLevel1Cluster,
unsigned NLevel1Cluster, index_t NLevel1Cluster,
unsigned KPerThreadLoop, index_t KPerThreadLoop,
unsigned BatchPerThread> index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{ {
unsigned mMyThreadOffsetA = 0; index_t mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0; index_t mMyThreadOffsetB = 0;
struct MatrixIndex struct MatrixIndex
{ {
unsigned batch; index_t batch;
unsigned row; index_t row;
unsigned col; index_t col;
}; };
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2() __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
...@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(BatchSize % BatchPerThread == 0, static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread"); "wrong! BatchSize is not dividable by BatchPerThread");
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster = constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster, static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
...@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n"); "wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0), static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n"); "wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat; constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat; constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0), (NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n"); "wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0), (NPerLevel0Cluster % NLevel0Cluster == 0),
...@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#endif #endif
} }
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{ {
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster = constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster; index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster; index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
unsigned level1_id = cluster_id / ThreadPerLevel0Cluster; index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster; index_t level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster; index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = cluster_id % ThreadPerLevel0Cluster; index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster; index_t level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster; index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread, return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
...@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// this should be optimized away if input is known // this should be optimized away if input is known
__device__ static MatrixIndex __device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c) GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{ {
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC; index_t m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC; index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC; index_t m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC; index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c, return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c, m_repeat * MPerLevel1Cluster + m_in_sub_c,
...@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM // thread A, B for GEMM
// A is transposed, b is not // A is transposed, b is not
...@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k // loop over k
#pragma unroll #pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{ {
// read first batch of A, B // read first batch of A, B
// copy A-sub to form A // copy A-sub to form A
#pragma unroll #pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
a_block_mtx, a_block_mtx,
...@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B // copy B-sub to form B
#pragma unroll #pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
b_block_mtx, b_block_mtx,
...@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch // loop over batch
#pragma unroll #pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{ {
// do current batch of gemm // do current batch of gemm
threadwise_gemm(a_thread_mtx, threadwise_gemm(a_thread_mtx,
...@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0) if(BlockMatrixStrideA != 0)
{ {
#pragma unroll #pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
a_block_mtx, a_block_mtx,
...@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0) if(BlockMatrixStrideB != 0)
{ {
#pragma unroll #pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
b_block_mtx, b_block_mtx,
...@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM // thread A, B for GEMM
// A is transposed, b is not // A is transposed, b is not
...@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k // loop over k
//#pragma unroll //#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{ {
// read first batch of A, B // read first batch of A, B
// copy A-sub to form A // copy A-sub to form A
//#pragma unroll //#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i) for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{ {
#if 1 #if 1
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j) for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{ {
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] = p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i, p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
...@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B // copy B-sub to form B
//#pragma unroll //#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i) for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{ {
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j) for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{ {
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] = p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i, p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
...@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch // loop over batch
//#pragma unroll //#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{ {
// do current batch of gemm // do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{ {
#if 0 #if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{ {
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{ {
const unsigned aindex = const index_t aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = const index_t cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC; c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
...@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4"); "asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{ {
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0); const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \ asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \ v_mac_f32 %0, %4, %5 \n \
...@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0) if(BlockMatrixStrideA != 0)
{ {
//#pragma unroll //#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i) for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{ {
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j) for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{ {
p_a_thread[a_thread_mtx.Get1dIndex(i, p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] = m_repeat * MPerThreadSubC + j)] =
...@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0) if(BlockMatrixStrideB != 0)
{ {
//#pragma unroll //#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i) for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{ {
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j) for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{ {
p_b_thread[b_thread_mtx.Get1dIndex(i, p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] = n_repeat * NPerThreadSubC + j)] =
...@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
} }
// do last batch of gemm // do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{ {
#if 0 #if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{ {
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{ {
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) + const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC; (BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
...@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4"); "asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{ {
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = const index_t cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC; c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
asm volatile("\n \ asm volatile("\n \
...@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
} }
} }
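The inline v_mac_f32 block above is a hand-scheduled multiply-accumulate for a fixed 16x4 thread tile; the disabled scalar path computes the same per-thread C += A^T * B update. A plain C++ sketch of that scalar accumulation, with tile sizes assumed for illustration:

#include <array>
#include <cassert>

// Illustrative only: per-thread accumulation for one k step, A stored transposed (k-major).
int main()
{
    constexpr int MPerThread = 16, NPerThread = 4, KPerThreadLoop = 1;

    std::array<float, KPerThreadLoop * MPerThread> a{};
    std::array<float, KPerThreadLoop * NPerThread> b{};
    std::array<float, MPerThread * NPerThread>     c{};
    a.fill(1.f);
    b.fill(2.f);

    for(int k = 0; k < KPerThreadLoop; ++k)
        for(int i = 0; i < MPerThread; ++i)
            for(int j = 0; j < NPerThread; ++j)
                c[i * NPerThread + j] += a[k * MPerThread + i] * b[k * NPerThread + j];

    // With all-ones A and all-twos B every C entry accumulates 2 * KPerThreadLoop.
    assert(c[0] == 2.f * KPerThreadLoop);
    return 0;
}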
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC> template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const FloatC* __restrict__ p_c_block) const
{ {
constexpr auto c_block_mtx = BlockMatrixC{}; constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor( constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{}); Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned c_thread_offset = const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC + c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col); c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
c_thread_sub_mtx, c_thread_sub_mtx,
......
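In the batched blockwise GEMM above, a flat thread id is first split into a batch work id and a cluster id, and the cluster id is then split into level-1 and level-0 (m, n) coordinates. A standalone sketch of that two-level decomposition with assumed cluster sizes:

#include <cassert>

// Illustrative only: two-level cluster decomposition of a flat thread id, mirroring
// batch_work_id / level1_{m,n}_id / level0_{m,n}_id above.
int main()
{
    constexpr int MLevel0Cluster = 4, NLevel0Cluster = 4;
    constexpr int MLevel1Cluster = 2, NLevel1Cluster = 2;
    constexpr int BatchThreadWork = 2;

    constexpr int ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; // 16
    constexpr int ThreadPerLevel1Cluster =
        ThreadPerLevel0Cluster * MLevel1Cluster * NLevel1Cluster;           // 64
    constexpr int BlockSize = BatchThreadWork * ThreadPerLevel1Cluster;     // 128

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int batch_work_id = tid / ThreadPerLevel1Cluster;
        const int cluster_id    = tid % ThreadPerLevel1Cluster;

        const int level1_id   = cluster_id / ThreadPerLevel0Cluster;
        const int level1_m_id = level1_id / NLevel1Cluster;
        const int level1_n_id = level1_id % NLevel1Cluster;

        const int level0_id   = cluster_id % ThreadPerLevel0Cluster;
        const int level0_m_id = level0_id / NLevel0Cluster;
        const int level0_n_id = level0_id % NLevel0Cluster;

        assert(batch_work_id < BatchThreadWork);
        assert(level1_m_id < MLevel1Cluster && level1_n_id < NLevel1Cluster);
        assert(level0_m_id < MLevel0Cluster && level0_n_id < NLevel0Cluster);
    }
    return 0;
}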
...@@ -3,16 +3,16 @@ ...@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp" #include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp" #include "threadwise_direct_convolution.hip.hpp"
template <unsigned BlockSize, template <index_t BlockSize,
class Float, class Float,
class InBlockDesc, class InBlockDesc,
class WeiBlockDesc, class WeiBlockDesc,
class OutBlockDesc, class OutBlockDesc,
unsigned NPerThread, index_t NPerThread,
unsigned KPerThread, index_t KPerThread,
unsigned CPerThread, index_t CPerThread,
unsigned HoPerThread, index_t HoPerThread,
unsigned WoPerThread> index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc, __device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block, Float* const __restrict__ p_in_block,
WeiBlockDesc, WeiBlockDesc,
...@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{}; constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{}; constexpr auto out_block_desc = OutBlockDesc{};
constexpr unsigned Y = wei_block_desc.GetLength(I2); constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3); constexpr index_t X = wei_block_desc.GetLength(I3);
constexpr unsigned InTileSizeH = HoPerThread + Y - 1; constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1; constexpr index_t InTileSizeW = WoPerThread + X - 1;
// divide thread work // divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
...@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc = constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides()); make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const unsigned thread_id = threadIdx.x; const index_t thread_id = threadIdx.x;
for(unsigned thread_work_id = thread_id; for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork; thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize) thread_work_id += BlockSize)
{ {
unsigned itmp = thread_work_id; index_t itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork); index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork); itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork; index_t y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork; index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
unsigned n_thread_data_begin = n_thread_work_id * NPerThread; index_t n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread; index_t k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread; index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread; index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;
unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding
Float p_out_thread[out_thread_desc.GetElementSpace()]; Float p_out_thread[out_thread_desc.GetElementSpace()];
...@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread, p_out_thread,
out_thread_desc.GetLengths()); out_thread_desc.GetLengths());
for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread) c_thread_data_begin += CPerThread)
{ {
// threadwise convolution // threadwise convolution
......
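A minimal host-side sketch of the flattened-index decomposition used in blockwise_direct_convolution above; the *ThreadWork counts are assumed purely for illustration. Each thread_work_id is peeled into (n, k, y, x) work coordinates, slowest dimension first, exactly as in the divide/subtract sequence in the loop body.

#include <cstdint>
#include <cstdio>

using index_t = uint32_t;

int main()
{
    // assumed work partition, chosen only to make the decomposition visible
    constexpr index_t NThreadWork = 2, KThreadWork = 4, YThreadWork = 2, XThreadWork = 8;

    for(index_t thread_work_id = 0;
        thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
        ++thread_work_id)
    {
        // peel dimensions off the flattened id, slowest (N) to fastest (X)
        index_t itmp             = thread_work_id;
        index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
        index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
        itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
        index_t y_thread_work_id = itmp / XThreadWork;
        index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

        printf("%3u -> (n %u, k %u, y %u, x %u)\n",
               thread_work_id,
               n_thread_work_id,
               k_thread_work_id,
               y_thread_work_id,
               x_thread_work_id);
    }

    return 0;
}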
#pragma once #pragma once
#include "threadwise_gemm.hip.hpp" #include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize, template <index_t BlockSize,
class BlockMatrixA, class BlockMatrixA,
class BlockMatrixB, class BlockMatrixB,
class ThreadMatrixC, class ThreadMatrixC,
bool TransA, bool TransA,
bool TransB, bool TransB,
bool TransC, bool TransC,
unsigned KPerThreadLoop, index_t KPerThreadLoop,
unsigned MThreadPerCluster, index_t MThreadPerCluster,
unsigned NThreadPerCluster, index_t NThreadPerCluster,
bool DistributeThreadAlongColumnFirst> bool DistributeThreadAlongColumnFirst>
struct BlockwiseGemmBlockABlockBThreadC struct BlockwiseGemmBlockABlockBThreadC
{ {
unsigned mMyThreadOffsetA = 0; index_t mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0; index_t mMyThreadOffsetB = 0;
struct MatrixIndex struct MatrixIndex
{ {
unsigned row; index_t row;
unsigned col; index_t col;
}; };
__device__ BlockwiseGemmBlockABlockBThreadC() __device__ BlockwiseGemmBlockABlockBThreadC()
...@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
#endif #endif
} }
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{ {
if(TransA && (!TransB) && (!TransC)) if(TransA && (!TransB) && (!TransC))
...@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!"); "wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol(); constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol(); constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work // divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0, static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
"MPerBlock % (MPerThread * MThreadPerCluster) != 0"); "MPerBlock % (MPerThread * MThreadPerCluster) != 0");
...@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0, static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
"NPerBlock % (NPerThread * NThreadPerCluster) != 0"); "NPerBlock % (NPerThread * NThreadPerCluster) != 0");
constexpr unsigned MClusterWork = constexpr index_t MClusterWork =
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster); (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
constexpr unsigned NClusterWork = constexpr index_t NClusterWork =
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster); (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
static_assert(BlockSize == static_assert(BlockSize ==
...@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst) if(DistributeThreadAlongColumnFirst)
{ {
const unsigned cluster_work_block_id = const index_t cluster_work_block_id =
thread_id / (MThreadPerCluster * NThreadPerCluster); thread_id / (MThreadPerCluster * NThreadPerCluster);
const unsigned thread_work_cluster_id = const index_t thread_work_cluster_id =
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster); thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork; const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const unsigned n_cluster_work_block_id = const index_t n_cluster_work_block_id =
cluster_work_block_id - m_cluster_work_block_id * NClusterWork; cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
const unsigned m_thread_work_cluster_id = const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
thread_work_cluster_id / NThreadPerCluster; const index_t n_thread_work_cluster_id =
const unsigned n_thread_work_cluster_id =
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster; thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
#if 0 #if 0
...@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
} }
// this should be optimized away if input is known // this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c, __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
unsigned n_in_c) index_t n_in_c)
{ {
return MatrixIndex{m_in_c, n_in_c}; return MatrixIndex{m_in_c, n_in_c};
} }
...@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not // a is transposed, b is not
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
...@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{ {
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + p_a_block + mMyThreadOffsetA +
...@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC ...@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
// if the following numbers are powers of 2, index calculations can be greatly reduced: // if the following numbers are powers of 2, index calculations can be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize, template <index_t BlockSize,
class BlockMatrixA, class BlockMatrixA,
class BlockMatrixB, class BlockMatrixB,
class ThreadMatrixC, class ThreadMatrixC,
unsigned MPerThreadSubC, index_t MPerThreadSubC,
unsigned NPerThreadSubC, index_t NPerThreadSubC,
unsigned MLevel0Cluster, index_t MLevel0Cluster,
unsigned NLevel0Cluster, index_t NLevel0Cluster,
unsigned MLevel1Cluster, index_t MLevel1Cluster,
unsigned NLevel1Cluster, index_t NLevel1Cluster,
unsigned KPerThreadLoop> index_t KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{ {
struct MatrixIndex struct MatrixIndex
{ {
unsigned row; index_t row;
unsigned col; index_t col;
}; };
unsigned mMyThreadOffsetA; index_t mMyThreadOffsetA;
unsigned mMyThreadOffsetB; index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2() __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{ {
constexpr unsigned ThreadPerLevel1Cluster = constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
...@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n"); "wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0), static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n"); "wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat; constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat; constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0), (NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n"); "wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0), (NPerLevel0Cluster % NLevel0Cluster == 0),
...@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col); mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
} }
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{ {
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned level1_id = thread_id / ThreadPerLevel0Cluster; index_t level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster; index_t level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster; index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = thread_id % ThreadPerLevel0Cluster; index_t level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster; index_t level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster; index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
} }
// this should be optimized away if input is known // this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c, __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
unsigned n_in_c) index_t n_in_c)
{ {
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC; index_t m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC; index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC; index_t m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC; index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c, return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c}; n_repeat * NPerLevel1Cluster + n_in_sub_c};
...@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol(); constexpr index_t M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM // thread A, B for GEMM
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
...@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll #pragma unroll
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
#pragma unroll #pragma unroll
// copy A-sub to form A // copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
a_block_mtx, a_block_mtx,
...@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll #pragma unroll
// copy B-sub to form B // copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
b_block_mtx, b_block_mtx,
...@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol(); constexpr index_t M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM // thread A, B for GEMM
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
...@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll #pragma unroll
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
#pragma unroll //#pragma unroll
// copy A-sub to form A // copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
a_block_mtx, a_block_mtx,
...@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths());
} }
#pragma unroll //#pragma unroll
// copy B-sub to form B // copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
b_block_mtx, b_block_mtx,
...@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False, False,
p_c_thread, p_c_thread,
f_accum); f_accum);
#else #elif 0
// inline asm // inline asm
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8, static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
"asm is only for 8x8"); "asm is only for 8x8");
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
{ {
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{ {
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0); const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \ asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \ v_mac_f32 %0, %8, %9 \n \
...@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol(); constexpr index_t M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM // thread A, B for GEMM
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
...@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B // preload A, B
#pragma unroll #pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A { // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
...@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
} }
#pragma unroll #pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B { // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx, threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
...@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true; bool even_loop = true;
#pragma unroll #pragma unroll
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K; for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop) k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k { // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
...@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// preload next A, B // preload next A, B
#pragma unroll #pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A { // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + p_a_block + mMyThreadOffsetA +
...@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
} }
#pragma unroll #pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B { // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx, threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + p_b_block + mMyThreadOffsetB +
...@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol(); constexpr index_t M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol(); constexpr index_t N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A-sub, B-sub, C-sub // thread A-sub, B-sub, C-sub
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
...@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll #pragma unroll
// loop over k // loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
// C-sub(s) in first row-wise subblock of C // C-sub(s) in first row-wise subblock of C
{ {
...@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll #pragma unroll
// copy next B-sub, and do GEMM // copy next B-sub, and do GEMM
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
b_block_mtx, b_block_mtx,
...@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll #pragma unroll
// loop over rest of row-wise subblock // loop over rest of row-wise subblock
// all B-sub(s) have been copied, so only A-sub(s) need to be copied // all B-sub(s) have been copied, so only A-sub(s) need to be copied
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
{ {
// copy an A-sub // copy an A-sub
threadwise_matrix_copy( threadwise_matrix_copy(
...@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths());
// do some GEMMs // do some GEMMs
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_gemm( threadwise_gemm(
a_thread_sub_mtx, a_thread_sub_mtx,
......
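A minimal host-side sketch of the two-level cluster decomposition performed by GetBeginOfThreadMatrixC in the v2 blockwise GEMM above; the cluster and sub-tile sizes are assumed (they happen to give a 128-thread block). Each thread id splits into a level-1 cluster coordinate and a level-0 coordinate inside that cluster, and both contribute to the row/column where the thread's C sub-tile begins.

#include <cstdint>
#include <cstdio>

using index_t = uint32_t;

int main()
{
    // assumed illustrative cluster configuration
    constexpr index_t MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr index_t MLevel0Cluster = 4, NLevel0Cluster = 4;
    constexpr index_t MLevel1Cluster = 2, NLevel1Cluster = 4;

    constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
    constexpr index_t BlockSize =
        ThreadPerLevel0Cluster * MLevel1Cluster * NLevel1Cluster; // 128 threads

    constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
    constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

    for(index_t thread_id = 0; thread_id < BlockSize; ++thread_id)
    {
        // level-1 coordinates: which level-0 cluster this thread belongs to
        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        // level-0 coordinates: position inside that level-0 cluster
        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        index_t row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
        index_t col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;

        if(thread_id < 4) // print only a few threads for illustration
            printf("thread %3u -> C begin (row %u, col %u)\n", thread_id, row, col);
    }

    return 0;
}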
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "Array.hip.hpp" #include "Array.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2> template <class T1, class T2>
struct is_same struct is_same
...@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b) ...@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
} }
#endif #endif
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b) __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{ {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
...@@ -11,3 +11,5 @@ ...@@ -11,3 +11,5 @@
#include "nvToolsExt.h" #include "nvToolsExt.h"
#include "helper_cuda.h" #include "helper_cuda.h"
#endif #endif
using index_t = uint32_t;
...@@ -8,5 +8,5 @@ struct integral_constant ...@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; } __host__ __device__ constexpr T Get() const { return value; }
}; };
template <unsigned N> template <index_t N>
using Number = integral_constant<unsigned, N>; using Number = integral_constant<index_t, N>;
#pragma once #pragma once
#include "config.h" #include "config.h"
template <class T, unsigned N> template <class T, index_t N>
struct vector_type struct vector_type
{ {
}; };
......
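A minimal sketch of how the new index_t alias and the Number<> integral_constant can carry compile-time extents, in the spirit of calls such as out_desc.GetLength(I0) earlier in the diff; the Descriptor type and its lengths here are stand-ins assumed for illustration, not the repository's ConstantTensorDescriptor.

#include <cstdint>
#include <cstdio>

using index_t = uint32_t;

template <class T, T v>
struct integral_constant
{
    static constexpr T value = v;
    constexpr T Get() const { return value; }
};

template <index_t N>
using Number = integral_constant<index_t, N>;

// stand-in for a ConstantTensorDescriptor: lengths are encoded in the type
template <index_t L0, index_t L1, index_t L2, index_t L3>
struct Descriptor
{
    template <index_t I>
    static constexpr index_t GetLength(Number<I>)
    {
        constexpr index_t lengths[] = {L0, L1, L2, L3};
        return lengths[I];
    }
};

int main()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Descriptor<2, 16, 4, 32>{}; // assumed N, K, Ho, Wo

    // everything resolves at compile time, so it can feed constexpr tile math
    constexpr index_t n  = desc.GetLength(I0);
    constexpr index_t wo = desc.GetLength(I3);
    printf("N %u, Wo %u\n", n, wo);

    return 0;
}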