Commit 766b0a9e authored by Chao Liu's avatar Chao Liu
Browse files

experimenting

parent f35c64eb
......@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
......@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
#if 1
// 3x3, 34x34
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 16;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(gridwise_direct_convolution_1<T,
InDesc,
......
......@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
......@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
#if 1
// 3x3, 34x34, 128 thread
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3, 34x34, 128 thread, fp16
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
......
......@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const Tensor<TInWei>& wei_kcyx,
OutDesc,
Tensor<TOut>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
// this suppose in / wei data type is int8x4
constexpr unsigned NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr index_t NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// vectorized input
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
......@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 16;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 16;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
......
......@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
......@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#if 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,
......
......@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
......@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
constexpr unsigned BlockSize = 8;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 1;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 1;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
constexpr index_t BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
......
......@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
......@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256;
#elif 1
// 1x1, 14x14, Vega 10
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
#if 1
......
......@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, unsigned long x) -> int { return init != (x % 2); })
[](bool init, index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
......@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr auto I3 = Number<3>{};
constexpr auto desc = TConstTensorDesc{};
std::initializer_list<unsigned> lengths = {
std::initializer_list<index_t> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
std::initializer_list<index_t> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
return TensorDescriptor(lengths, strides);
......@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads,
UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
......@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
......@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 84;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 64;
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 64;
constexpr index_t HI = 112;
constexpr index_t WI = 112;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 86;
constexpr index_t K = 512;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16;
constexpr unsigned C = 192;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
constexpr index_t N = 16;
constexpr index_t C = 192;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 32;
constexpr unsigned WI = 32;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 14x14 image
constexpr unsigned N = 128;
constexpr unsigned C = 2048;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128;
constexpr unsigned C = 512;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
auto lower_pads = Sequence<HPad, WPad>{};
......@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
index_t nrepeat = atoi(argv[2]);
if(do_verification)
{
......
#pragma once
template <class TData, unsigned NSize>
template <class TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize;
static constexpr index_t nSize = NSize;
unsigned mData[nSize];
index_t mData[nSize];
template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{
}
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
__host__ __device__ TData operator[](index_t i) const { return mData[i]; }
};
#pragma once
#include "common.hip.hpp"
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
__host__ __device__ constexpr ConstantMatrixDescriptor()
......@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
}
__host__ __device__ constexpr unsigned NRow() const { return NRow_; }
__host__ __device__ constexpr index_t NRow() const { return NRow_; }
__host__ __device__ constexpr unsigned NCol() const { return NCol_; }
__host__ __device__ constexpr index_t NCol() const { return NCol_; }
__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
__host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
__host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol;
#endif
}
template <unsigned SubNRow, unsigned SubNCol>
template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const
{
......@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
}
};
template <unsigned NRow, unsigned NCol>
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
template <unsigned NRow, unsigned NCol, unsigned RowStride>
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
......
......@@ -2,35 +2,35 @@
#include "common.hip.hpp"
// this is ugly, only for 2d
template <unsigned L0, unsigned L1>
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
return Sequence<L1, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}
// this is ugly, only for 8d
template <unsigned L0,
unsigned L1,
unsigned L2,
unsigned L3,
unsigned L4,
unsigned L5,
unsigned L6,
unsigned L7>
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
......@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
}
// this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align>
template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>)
{
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>)
{
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim;
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr index_t nDim = Lengths::nDim;
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
}
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetLength(Number<I>) const
{
return Lengths{}.Get(Number<I>{});
}
template <unsigned I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetStride(Number<I>) const
{
return Strides{}.Get(Number<I>{});
}
......@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return Type{}.GetLength(idim);
}
};
__host__ __device__ constexpr unsigned GetElementSize() const
__host__ __device__ constexpr index_t GetElementSize() const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct multiply
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a * b;
}
......@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
}
};
template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a + b;
}
......@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
}
template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const
__host__ __device__ index_t Get1dIndex(Is... is) const
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<unsigned, nDim>(is...);
const auto multi_id = Array<index_t, nDim>(is...);
unsigned id = 0;
index_t id = 0;
static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get();
constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim);
#endif
});
return id;
......@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
template <unsigned IDim, unsigned NVector>
template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{
assert(false); // not implemented
......@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, unsigned Align>
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
......@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension();
constexpr auto desc = TDesc{};
constexpr index_t ndim = desc.GetDimension();
static_assert(ndim >= 2 && ndim <= 8, "wrong!");
......
......@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
template <unsigned... Is>
template <index_t... Is>
struct Sequence
{
using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is);
static constexpr index_t nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
const index_t mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t Get(Number<I>) const
{
return mData[I];
}
// this is ugly, only for nDIm = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
constexpr index_t NR0 = old_sequence.mData[I0];
constexpr index_t NR1 = old_sequence.mData[I1];
constexpr index_t NR2 = old_sequence.mData[I2];
constexpr index_t NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
......@@ -41,7 +41,7 @@ struct Sequence
assert(false);
}
template <unsigned I>
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
return Sequence<Is..., I>{};
......@@ -56,14 +56,14 @@ struct Sequence
}
};
template <unsigned... Is, unsigned I>
template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{};
}
template <class F, unsigned... Xs, unsigned... Ys>
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
......@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{};
}
template <unsigned... Xs, unsigned... Ys>
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
__host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{
return x + y;
}
......@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
}
template <unsigned... Is>
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
return sequence_pop_back(Type{});
......
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
......@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]);
}
......@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]);
}
......@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
......@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[2];
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
......@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
......@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[2];
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
......@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
}
}
template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
......@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
......@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
struct Blockwise2dTensorCopy1
{
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
......@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
unsigned ThreadPerDim0,
unsigned ThreadPerDim1>
index_t ThreadPerDim0,
index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
unsigned mThreadId0;
unsigned mThreadId1;
index_t mThreadId0;
index_t mThreadId1;
__device__ Blockwise2dTensorCopy2()
{
......@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
constexpr bool align_v2 =
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
constexpr index_t L1 = SrcOpLengths{}.Get(I1);
constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr unsigned Dim1V2Loop =
constexpr index_t Dim1V2Loop =
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
constexpr unsigned Dim1V1Loop =
constexpr index_t Dim1V1Loop =
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
ThreadPerDim1;
constexpr bool d1_has_tail =
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
{
unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 =
index_t did1 =
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
......@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
// dim-1 tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
......@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
// dim-0 tail
if(d0_has_tail)
{
unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
if(did0 < L0)
{
// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
......@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
// tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
......@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise2dTensorCopy3()
{
......@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride is big enough,
......@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
......@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
}
}
const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
......@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
......@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
......@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
......@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
}
}
__device__ constexpr unsigned GetRegisterClipboardSize() const
__device__ constexpr index_t GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
}
......@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
......@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
......@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
......@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
......@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
......@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
......
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
......@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
......@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
......@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
......@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
......@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
......@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
......@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
}
}
template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
......@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
......@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
......@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
......@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](unsigned is) {
unsigned did[4];
auto f_copy = [&](index_t is) {
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
......@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
f_copy(is);
}
......@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
......@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
}
};
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
......@@ -315,15 +315,15 @@ template <unsigned BlockSize,
struct BlockwiseChwnTensorCopyPadded
{
__device__ void Run(const Float* __restrict__ p_src,
unsigned c_block_data_begin,
unsigned ho_block_data_begin,
unsigned wo_block_data_begin,
unsigned n_block_data_begin,
index_t c_block_data_begin,
index_t ho_block_data_begin,
index_t wo_block_data_begin,
index_t n_block_data_begin,
Float* __restrict__ p_dst,
unsigned h_block_pad_low,
unsigned w_block_pad_low,
unsigned h_block_pad_up,
unsigned w_block_pad_up) const
index_t h_block_pad_low,
index_t w_block_pad_low,
index_t h_block_pad_up,
index_t w_block_pad_up) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp =
p_src +
......@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
}
#endif
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
......@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
......@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
......@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low ||
......@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded
// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise4dTensorCopy3()
{
......@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
......@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
......@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
}
}
const unsigned thread_id_d0 =
const index_t thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
unsigned itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
index_t itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const unsigned thread_id_d2 = itmp / thread_per_d3;
const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
const index_t thread_id_d2 = itmp / thread_per_d3;
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
......@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
......@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned nloop_d1 = L1 / thread_per_d1;
constexpr unsigned nloop_d2 = L2 / thread_per_d2;
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll
for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_offset =
const index_t src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const unsigned dst_offset =
const index_t dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
......
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned BatchPerThread,
unsigned KPerThreadLoop,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t BatchPerThread,
index_t KPerThreadLoop,
bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
......@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
......@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize");
......@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
// num of operations can be reduced
const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const unsigned m_work_id = itmp / NThreadWork;
const unsigned n_work_id = itmp - m_work_id * NThreadWork;
const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const index_t m_work_id = itmp / NThreadWork;
const index_t n_work_id = itmp - m_work_id * NThreadWork;
return MatrixIndex{
b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
......@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
return MatrixIndex{batch_in_c, m_in_c, n_in_c};
}
......@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
......@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of a, b
threadwise_matrix_copy(a_block_mtx,
......@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
b_thread_mtx.GetLengths());
// loop over batch
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
......@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
};
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop,
unsigned BatchPerThread>
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
......@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
......@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
......@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
......@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
......@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
......@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
......@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
......@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
......@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
......@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
......@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
......@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
//#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
#if 1
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
......@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
......@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
//#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
const index_t aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
......@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
......@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
......@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
......@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
// do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
......@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex =
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
asm volatile("\n \
......@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned c_thread_offset =
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,
......
......@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class InBlockDesc,
class WeiBlockDesc,
class OutBlockDesc,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread>
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block,
WeiBlockDesc,
......@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{};
constexpr unsigned Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3);
constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr index_t X = wei_block_desc.GetLength(I3);
constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1;
constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr index_t InTileSizeW = WoPerThread + X - 1;
// divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0
if(threadIdx.x == 0)
......@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;
for(unsigned thread_work_id = thread_id;
for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize)
{
unsigned itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
index_t itmp = thread_work_id;
index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
index_t y_thread_work_id = itmp / XThreadWork;
index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
index_t n_thread_data_begin = n_thread_work_id * NPerThread;
index_t k_thread_data_begin = k_thread_work_id * KPerThread;
index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;
unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding
index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding
Float p_out_thread[out_thread_desc.GetElementSpace()];
......@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread,
out_thread_desc.GetLengths());
for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread)
{
// threadwise convolution
......
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned KPerThreadLoop,
unsigned MThreadPerCluster,
unsigned NThreadPerCluster,
index_t KPerThreadLoop,
index_t MThreadPerCluster,
index_t NThreadPerCluster,
bool DistributeThreadAlongColumnFirst>
struct BlockwiseGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
__device__ BlockwiseGemmBlockABlockBThreadC()
......@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
......@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
"MPerBlock % (MPerThread * MThreadPerCluster) != 0");
......@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
"NPerBlock % (NPerThread * NThreadPerCluster) != 0");
constexpr unsigned MClusterWork =
constexpr index_t MClusterWork =
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
constexpr unsigned NClusterWork =
constexpr index_t NClusterWork =
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
static_assert(BlockSize ==
......@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
const unsigned cluster_work_block_id =
const index_t cluster_work_block_id =
thread_id / (MThreadPerCluster * NThreadPerCluster);
const unsigned thread_work_cluster_id =
const index_t thread_work_cluster_id =
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const unsigned n_cluster_work_block_id =
const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const index_t n_cluster_work_block_id =
cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
const unsigned m_thread_work_cluster_id =
thread_work_cluster_id / NThreadPerCluster;
const unsigned n_thread_work_cluster_id =
const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
const index_t n_thread_work_cluster_id =
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
#if 0
......@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
return MatrixIndex{m_in_c, n_in_c};
}
......@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
......@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
......@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
......@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
......@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
......@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
......@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
......@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
......@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
......@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
//#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
......@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
}
#pragma unroll
//#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
......@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False,
p_c_thread,
f_accum);
#else
#elif 0
// inline asm
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
"asm is only for 8x8");
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
{
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \
......@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
......@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
......@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
......@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true;
#pragma unroll
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
......@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
......@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
......@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A-sub, B-sub, C-sub
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
......@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// C-sub(s) in first row-wise subblock of C
{
......@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy next B-sub, and do GEMM
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
......@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// loop over rest of row-wise subblock
// all B-sub(s) has been copied, so only A-sub(s) need to be copied
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
{
// copy a A-sub
threadwise_matrix_copy(
......@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
// do some GEMMs
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_gemm(
a_thread_sub_mtx,
......
......@@ -5,9 +5,9 @@
#include "Array.hip.hpp"
#include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2>
struct is_same
......@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
}
#endif
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{
return (a + b - 1) / b;
}
......@@ -11,3 +11,5 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"
#endif
using index_t = uint32_t;
......@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; }
};
template <unsigned N>
using Number = integral_constant<unsigned, N>;
template <index_t N>
using Number = integral_constant<index_t, N>;
#pragma once
#include "config.h"
template <class T, unsigned N>
template <class T, index_t N>
struct vector_type
{
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment