Commit 1b323316 authored by Chao Liu

add another blockwise gemm

parent 5e776504
@@ -14,6 +14,7 @@
 #include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh"
+#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
 //#include "device_winograd_convolution.cuh"
 struct GeneratorTensor_1
@@ -391,7 +392,7 @@ int main()
 constexpr unsigned HPad = 0;
 constexpr unsigned WPad = 0;
-#elif 1
+#elif 0
 // 3x3, 34x34
 constexpr unsigned N = 64;
 constexpr unsigned C = 256;
@@ -484,7 +485,7 @@ int main()
 constexpr unsigned HPad = 1;
 constexpr unsigned WPad = 1;
-#elif 0
+#elif 1
 // 1x1 filter, 28x28 image
 constexpr unsigned N = 16;
 constexpr unsigned C = 256;
@@ -591,8 +592,10 @@ int main()
 device_implicit_gemm_convolution_1_chwn_csrk_khwn
 #elif 0
 device_implicit_gemm_convolution_2_cnhw_srck_knhw
-#elif 1
+#elif 0
 device_implicit_gemm_convolution_2_cnhw_csrk_knhw
+#elif 1
+device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2
 #endif
 (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
@@ -608,7 +611,7 @@ int main()
 nrepeat);
 #endif
-#if 1
+#if 0
 if(S == 3 && R == 3)
 {
 host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
...
@@ -67,7 +67,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
-#if 1
+#if 0
 // 3x3, 34x34
 constexpr unsigned BPerBlock = 128;
 constexpr unsigned KPerBlock = 64;
@@ -90,31 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned WeiBlockCopyDataPerRead = 4;
 constexpr unsigned BlockSize = 128;
-#elif 0
-// 1x1, 28x28
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned BPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned GemmThreadPerColumnPerCluster = 4;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
-constexpr unsigned BlockSize = 64;
 #elif 1
-// 1x1, 28x28 try
+// 1x1, 28x28
 constexpr unsigned BPerBlock = 64;
 constexpr unsigned KPerBlock = 64;
 constexpr unsigned CPerBlock = 8;
...
#pragma once
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
// convert in_nchw to in_cnhw
auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));
auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// convert wei_kcsr to wei_csrk
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
};
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
std::thread::hardware_concurrency());
// convert out_nkhw to out_knhw
auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
#if 0
// 1x1, 28x28
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned GemmMPerThreadSubC = 16;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 8;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 2;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
// 1x1, 28x28 try
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 8;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#endif
constexpr unsigned GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
dim3 block_dim(BlockSize);
dim3 grid_dim(GridSize);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead +
BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());
in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_knhw_device_buf.ToDevice(out_knhw.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
{
cudaEvent_t start, stop;
float elapsedTime;
cudaEventCreate(&start);
cudaEventRecord(start, 0);
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2<GridSize,
BlockSize,
T,
decltype(in_cnhw_desc),
decltype(wei_csrk_desc),
decltype(out_knhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>
<<<grid_dim, block_dim>>>(in_cnhw_desc,
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
wei_csrk_desc,
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
out_knhw_desc,
static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
cudaEventCreate(&stop);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime);
usleep(std::min(elapsedTime * 1000, float(10000)));
}
checkCudaErrors(cudaGetLastError());
out_knhw_device_buf.FromDevice(out_knhw.mData.data());
// convert out_knhw to out_nkhw
auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
};
make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
std::thread::hardware_concurrency());
}
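For a rough sense of the launch configuration this wrapper produces with the active 1x1, 28x28 settings (N = 16 and the 28x28 image come from the driver above; K = 128 is assumed here purely for illustration, since the driver's K is not visible in this diff):

B        = N * Hi * Wi = 16 * 28 * 28 = 12544
GridSize = ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock)
         = (12544 / 64) * (128 / 64)
         = 196 * 2 = 392 blocks of BlockSize = 64 threads each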
@@ -22,9 +22,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 struct MatrixIndex
 {
-unsigned batch_begin;
-unsigned row_begin;
-unsigned col_begin;
+unsigned batch;
+unsigned row;
+unsigned col;
 };
 __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -32,15 +32,15 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
 const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
-const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
-((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
-: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
-mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
-((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
-: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
+const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
+((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
+: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
+mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
+((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
 #if 0
 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -52,16 +52,16 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 printf("%u %u, %u %u %u, %u %u\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-c_thread_mtx_index.batch_begin,
-c_thread_mtx_index.row_begin,
-c_thread_mtx_index.col_begin,
+c_thread_mtx_index.batch,
+c_thread_mtx_index.row,
+c_thread_mtx_index.col,
 mMyThreadOffsetA,
 mMyThreadOffsetB);
 }
 #endif
 }
-__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
+__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
 {
 if(TransA && (!TransB) && (!TransC))
@@ -237,8 +237,8 @@ struct BlockwiseGemmBlockABlockBThreadC
 struct MatrixIndex
 {
-unsigned row_begin;
-unsigned col_begin;
+unsigned row;
+unsigned col;
 };
 __device__ BlockwiseGemmBlockABlockBThreadC()
@@ -246,13 +246,13 @@ struct BlockwiseGemmBlockABlockBThreadC
 const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
 const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
-const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
-: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin);
-mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
-: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0);
+const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
+: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
+mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
 #if 0
 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -264,16 +264,16 @@ struct BlockwiseGemmBlockABlockBThreadC
 printf("%u %u, %u %u %u, %u %u\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-c_thread_mtx_index.batch_begin,
-c_thread_mtx_index.row_begin,
-c_thread_mtx_index.col_begin,
+c_thread_mtx_index.batch,
+c_thread_mtx_index.row,
+c_thread_mtx_index.col,
 mMyThreadOffsetA,
 mMyThreadOffsetB);
 }
 #endif
 }
-__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
+__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
 {
 if(TransA && (!TransB) && (!TransC))
@@ -359,6 +359,13 @@ struct BlockwiseGemmBlockABlockBThreadC
 }
 }
+// this should be optimized away if input is known
+__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
+unsigned n_in_c)
+{
+return MatrixIndex{m_in_c, n_in_c};
+}
 template <class FloatA, class FloatB, class FloatC, class Accumulator>
 __device__ void Run(FloatA* const p_a_block,
 FloatB* const p_b_block,
@@ -420,3 +427,215 @@ struct BlockwiseGemmBlockABlockBThreadC
 }
 }
 };
// if the following numbers are powers of 2, the index calculations are greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
};
unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
"wrong! Cannot evenly divide work among Level0Cluster\n");
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
"wrong! thread work size is wrong\n");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
{
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
{
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run(FloatA* const p_a_block,
FloatB* const p_b_block,
FloatC* p_c_thread,
Accumulator f_accum) const
{
constexpr auto True = Constant<bool, true>{};
constexpr auto False = Constant<bool, false>{};
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
const auto a_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
// thread A-sub, B-sub for copy
const auto a_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<MPerThreadSubC>{},
Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<NPerThreadSubC>{},
Number<NPerThread>{}); // constexpr doesn't compile
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
k_begin * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
}
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
k_begin * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread,
f_accum);
}
}
};
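The two-level cluster decomposition used by GetBeginOfThreadMatrixC can be reproduced on the host with a small sketch like the one below (illustration only; the parameter values mirror the 1x1, 28x28 configuration in the new device wrapper: MPerThreadSubC = NPerThreadSubC = 4, MLevel0Cluster = 8, NLevel0Cluster = 2, MLevel1Cluster = 1, NLevel1Cluster = 4, BlockSize = 64):

#include <cstdio>

int main()
{
    constexpr unsigned MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr unsigned MLevel0Cluster = 8, NLevel0Cluster = 2;
    constexpr unsigned MLevel1Cluster = 1, NLevel1Cluster = 4;
    constexpr unsigned BlockSize =
        MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; // 64

    constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; // 16
    constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;      // 32
    constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;      // 8

    for(unsigned tid = 0; tid < BlockSize; ++tid)
    {
        // same index math as GetBeginOfThreadMatrixC
        unsigned level1_id   = tid / ThreadPerLevel0Cluster;
        unsigned level1_m_id = level1_id / NLevel1Cluster;
        unsigned level1_n_id = level1_id % NLevel1Cluster;
        unsigned level0_id   = tid % ThreadPerLevel0Cluster;
        unsigned level0_m_id = level0_id / NLevel0Cluster;
        unsigned level0_n_id = level0_id % NLevel0Cluster;

        unsigned row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
        unsigned col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;
        printf("thread %2u -> C sub-tile begin (row %2u, col %2u)\n", tid, row, col);
    }
    return 0;
}

With these values every thread lands on a distinct 4x4 sub-tile inside a 32x32 region of the 64x64 block tile; each thread then owns MRepeat x NRepeat = 2 x 2 such sub-tiles, with the second copy offset by MPerLevel1Cluster = NPerLevel1Cluster = 32, which is what GetDistanceFromBeginOfThreadMatrixC encodes. When the cluster sizes are powers of two, the divisions and modulos above reduce to shifts and masks, which is the point of the comment at the top of the struct.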
@@ -208,13 +208,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
 }
 const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-const unsigned n_thread_data_begin =
-matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned ho_thread_data_begin = matrix_c_index.batch;
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
 #if 0
 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
...
@@ -262,13 +262,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
 }
 const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-const unsigned n_thread_data_begin =
-matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned ho_thread_data_begin = matrix_c_index.batch;
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
 #if 0
 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
...
@@ -318,13 +318,12 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
 }
 const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-const unsigned n_thread_data_begin =
-matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned ho_thread_data_begin = matrix_c_index.batch;
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
 #if 0
 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
...
@@ -228,15 +228,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
 }
 const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
 #if 0
-printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
+printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col);
 #endif
-const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
+const unsigned ho_thread_data_begin = matrix_c_index.batch;
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread;
 #if 1
 // output: register to global mem,
...
@@ -205,13 +205,12 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
 }
 const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
-const unsigned n_thread_data_begin =
-matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
+blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned ho_thread_data_begin = matrix_c_index.batch;
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
 // output: register to global mem,
 // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
...
@@ -75,6 +75,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
 // tensor view of blockwise input and weight
+// be careful of alignment
 constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
 Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
@@ -245,11 +246,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 }
 // output: register to global mem,
-const auto matrix_c_index =
-blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned b_thread_data_begin = matrix_c_index.col;
 const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
 const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -257,11 +257,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 #if 0
 if(get_block_1d_id() == 0)
 {
-printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-matrix_c_index.row_begin,
-matrix_c_index.col_begin,
+matrix_c_index.row,
+matrix_c_index.col,
 k_data_begin,
 b_data_begin,
 p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
...
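For reference only (not something this commit adds), the math that both this kernel and the new _gemm_2 variant implement can be written as a naive host loop. Here wei_csrk is in [C][S][R][K] layout, in_cb is the flattened [C][B + BGhostRead] input with b = n * Hi * Wi + h * Wi + w, and out_kb is [K][B], of which only positions whose decoded h and w fall inside Ho x Wo are meaningful (stride 1, no padding assumed):

// Naive host-side sketch of the implicit-GEMM formulation (illustration only).
void reference_implicit_gemm(const float* wei_csrk,
                             const float* in_cb,
                             float* out_kb,
                             unsigned C, unsigned S, unsigned R, unsigned K,
                             unsigned Wi, unsigned B)
{
    const unsigned BGhostRead = (S - 1) * Wi + (R - 1);
    for(unsigned k = 0; k < K; ++k)
        for(unsigned b = 0; b < B; ++b)
        {
            float acc = 0;
            for(unsigned c = 0; c < C; ++c)
                for(unsigned s = 0; s < S; ++s)
                    for(unsigned r = 0; r < R; ++r)
                        acc += wei_csrk[((c * S + s) * R + r) * K + k] *
                               in_cb[c * (B + BGhostRead) + b + s * Wi + r];
            out_kb[k * B + b] = acc;
        }
}

The kernels block this triple loop: the c loop becomes the CPerBlock-deep accumulation over LDS tiles, and each (s, r) pair is one blockwise GEMM whose B operand is the input tile shifted by s * Wi + r.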
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "blockwise_2d_tensor_op.cuh"
#include "threadwise_2d_tensor_op.cuh"
#include "blockwise_gemm.cuh"
// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
Float* const __restrict__ p_in_global,
WeiGlobalDesc,
Float* const __restrict__ p_wei_global,
OutGlobalDesc,
Float* __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_cnhw_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_knhw_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_cnhw_global_desc.GetLength(I0);
constexpr unsigned N = in_cnhw_global_desc.GetLength(I1);
constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
constexpr unsigned K = out_knhw_global_desc.GetLength(I0);
constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");
print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
printf("KPerBlock %u\n", KPerBlock);
}
#endif
// blockwise in copy
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<KPerBlock>{},
Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<BPerBlock>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
#if 0
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
true,
false,
false,
GemmKPerThreadLoop,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
#else
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
Float* p_wei_global_block_offset =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
__syncthreads())
{
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
__syncthreads();
// a series of GEMM
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block + s * Wi + r,
p_out_thread,
f_accum);
}
}
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_begin.row,
c_thread_mtx_begin.col,
k_thread_data_begin,
b_thread_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
}
#endif
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned n_data = b_data / (Hi * Wi);
unsigned itmp = b_data - n_data * (Hi * Wi);
unsigned h_data = itmp / Wi;
unsigned w_data = itmp - h_data * Wi;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
k,
b,
k_data,
n_data,
h_data,
w_data,
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
}
#endif
if(n_data < N && h_data < Ho && w_data < Wo)
{
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
}
}
}
}
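Two pieces of index arithmetic in this kernel are worth spelling out. With the input flattened as b = n * (Hi * Wi) + h * Wi + w, the element needed by filter tap (s, r) for output position b sits at offset

b + s * Wi + r

so a block that owns BPerBlock consecutive b values reads at most

BGhostRead = (S - 1) * Wi + (R - 1)

elements past its own range (for example 2 * 34 + 2 = 70 for a 3x3 filter on a 34x34 input), which is why the device wrapper over-allocates the input buffer by BGhostRead + BPerBlock elements. The write-back reverses the flattening; for example, with Hi = Wi = 28 (Hi * Wi = 784) and b_data = 10000:

n = 10000 / 784 = 12,  itmp = 10000 - 12 * 784 = 592,  h = 592 / 28 = 21,  w = 592 - 21 * 28 = 4

and the store is skipped unless n < N, h < Ho and w < Wo.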
@@ -290,11 +290,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
 }
 // output: register to global mem,
-const auto matrix_c_index =
-blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned b_thread_data_begin = matrix_c_index.col;
 const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
 const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -302,11 +301,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
 #if 0
 if(get_block_1d_id() == 0)
 {
-printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-matrix_c_index.row_begin,
-matrix_c_index.col_begin,
+matrix_c_index.row,
+matrix_c_index.col,
 k_data_begin,
 b_data_begin,
 p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
...
@@ -217,11 +217,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
 }
 // output: register to global mem,
-const auto matrix_c_index =
-blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned b_thread_data_begin = matrix_c_index.col;
 const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
 const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -229,11 +228,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
 #if 0
 if(get_block_1d_id() == 0)
 {
-printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-matrix_c_index.row_begin,
-matrix_c_index.col_begin,
+matrix_c_index.row,
+matrix_c_index.col,
 k_data_begin,
 b_data_begin,
 p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
...
@@ -276,11 +276,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
 }
 // output: register to global mem,
-const auto matrix_c_index =
-blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
-const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+const unsigned k_thread_data_begin = matrix_c_index.row;
+const unsigned b_thread_data_begin = matrix_c_index.col;
 const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
 const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -288,11 +287,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
 #if 0
 if(get_block_1d_id() == 0)
 {
-printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
+printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
 get_block_1d_id(),
 get_thread_local_1d_id(),
-matrix_c_index.row_begin,
-matrix_c_index.col_begin,
+matrix_c_index.row,
+matrix_c_index.col,
 k_data_begin,
 b_data_begin,
 p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
...