Commit 07f16673 authored by Chao Liu

add LDS double buffer for CNHW implicit GEMM
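The new kernel overlaps the global-to-LDS copy of the next C-slice with the blockwise GEMM on the current slice, at the cost of doubling the LDS footprint. A minimal sketch of the pattern, with hypothetical copy_to_lds() and block_gemm_accumulate() helpers standing in for this repo's Blockwise2dTensorCopy* and BlockwiseGemm* functors:

// Sketch only: copy_to_lds() and block_gemm_accumulate() are hypothetical
// stand-ins, not this repo's actual API.
template <class Float, unsigned TileSize, unsigned NumSlices>
__device__ void lds_double_buffer_loop(const Float* a_global,
                                       const Float* b_global,
                                       unsigned a_slice_stride,
                                       unsigned b_slice_stride,
                                       Float* c_regs)
{
    __shared__ Float a_lds[2][TileSize]; // ping-pong buffers
    __shared__ Float b_lds[2][TileSize];

    copy_to_lds(a_global, a_lds[0]); // prologue: preload slice 0
    copy_to_lds(b_global, b_lds[0]);

    for(unsigned i = 1; i < NumSlices; ++i)
    {
        const unsigned cur = (i - 1) % 2;
        const unsigned nxt = i % 2;

        __syncthreads(); // slice i-1 is now visible to the whole block

        // issue global->LDS copies for slice i ...
        copy_to_lds(a_global + i * a_slice_stride, a_lds[nxt]);
        copy_to_lds(b_global + i * b_slice_stride, b_lds[nxt]);

        // ... while the FMAs consume slice i-1
        block_gemm_accumulate(a_lds[cur], b_lds[cur], c_regs);
    }

    __syncthreads(); // epilogue: compute on the last slice loaded
    block_gemm_accumulate(a_lds[(NumSlices - 1) % 2], b_lds[(NumSlices - 1) % 2], c_regs);
}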

parent c8667736
......@@ -14,7 +14,6 @@
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
//#include "device_winograd_convolution.cuh"
struct GeneratorTensor_1
......@@ -392,7 +391,7 @@ int main()
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
#elif 0
#elif 1
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
......@@ -592,10 +591,8 @@ int main()
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw
#elif 0
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
#elif 1
device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
#endif
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
......
#pragma once
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh"
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc>
......@@ -67,17 +67,24 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
#if 0
#if 1
// 3x3, 34x34
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 8;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
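// note: the two cluster levels must tile the block:
// BlockSize = (MLevel0Cluster * MLevel1Cluster) * (NLevel0Cluster * NLevel1Cluster) = (8*1) * (2*8) = 128,
// and each thread accumulates a KPerThread x BPerThread C-tile built from 4x4 sub-tiles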
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
......@@ -90,17 +97,24 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 1
#elif 0
// 1x1, 28x28
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 8;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
......@@ -113,6 +127,36 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
// 1x1, 28x28, experimental config
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 8;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
#endif
constexpr unsigned GridSize =
......@@ -143,8 +187,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
#if 1
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
#elif 1
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
#else
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
......@@ -157,9 +201,15 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
CPerBlock,
BPerThread,
KPerThread,
CPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
......
#pragma once
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
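// ghost elements: with B = flatten(N, Hi, Wi), filter tap (s, r) reads the input at
// offset s * Wi + r, so a block needs (S - 1) * Wi + (R - 1) elements past its BPerBlock window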
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
// convert in_nchw to in_cnhw
auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));
auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// convert wei_kcsr to wei_csrk
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
};
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
std::thread::hardware_concurrency());
// convert out_nkhw to out_knhw
auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
#if 1
// 1x1, 28x28
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 8;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 1;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#endif
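// one thread block per [KPerBlock, BPerBlock] output tile, with B = N * Hi * Wi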
constexpr unsigned GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
dim3 block_dim(BlockSize);
dim3 grid_dim(GridSize);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead +
BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());
in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_knhw_device_buf.ToDevice(out_knhw.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
{
cudaEvent_t start, stop;
float elapsedTime;
cudaEventCreate(&start);
cudaEventRecord(start, 0);
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2<GridSize,
BlockSize,
T,
decltype(in_cnhw_desc),
decltype(wei_csrk_desc),
decltype(out_knhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>
<<<grid_dim, block_dim>>>(in_cnhw_desc,
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
wei_csrk_desc,
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
out_knhw_desc,
static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
cudaEventCreate(&stop);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime);
usleep(std::min(elapsedTime * 1000, float(10000)));
}
checkCudaErrors(cudaGetLastError());
out_knhw_device_buf.FromDevice(out_knhw.mData.data());
// convert out_knhw to out_nkhw
auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
};
make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
std::thread::hardware_concurrency());
}
......@@ -680,6 +680,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
Number<NPerThreadSubC>{},
Number<NPerThread>{}); // constexpr doesn't compile
// register
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
......@@ -687,7 +688,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
......@@ -717,7 +717,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true;
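// registers hold two A/B fragments: the next k-fragment is preloaded while the FMAs consume the current one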
#pragma unroll
for(unsigned k_begin = 0; k_begin + 1 < K;
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
......@@ -727,7 +727,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
......@@ -767,8 +766,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// last loop
{
even_loop = !even_loop;
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
......
......@@ -19,9 +19,15 @@ template <unsigned GridSize,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
......@@ -179,6 +185,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
#if 0
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
......@@ -186,10 +193,24 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
true,
false,
false,
CPerThread,
GemmKPerThreadLoop,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
#else
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size =
......@@ -237,7 +258,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1
#if 0
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
......@@ -251,13 +272,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0
if(get_block_1d_id() == 0)
......@@ -277,8 +296,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
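// once threads are clustered, a thread's (k, b) register entries are not contiguous in the
// block's C matrix, so query the blockwise GEMM for the true offset of entry (k, b)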
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned n_data = b_data / (Hi * Wi);
unsigned itmp = b_data - n_data * (Hi * Wi);
......
......@@ -34,8 +34,8 @@ template <unsigned GridSize,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer(
InGlobalDesc,
Float* const __restrict__ p_in_global,
WeiGlobalDesc,
Float* const __restrict__ p_wei_global,
......@@ -123,7 +123,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
......@@ -149,7 +149,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
......@@ -223,14 +223,12 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
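// round each LDS buffer up to a multiple of max_align so vectorized (DataPerRead-wide) accesses stay aligned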
// LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)];
__shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)];
Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
......@@ -238,28 +236,78 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
Float* p_wei_global_block_offset =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
// preload data into LDS
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0);
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
bool even_loop = true;
for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
even_loop = !even_loop)
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0;
Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;
__syncthreads();
// load next data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
// compute on current data
// a series of GEMM
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block_now + s * Wi + r,
p_out_thread,
f_accum);
}
}
}
// last computation
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
__syncthreads();
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block_now + s * Wi + r,
p_out_thread,
f_accum);
}
......
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "blockwise_2d_tensor_op.cuh"
#include "threadwise_2d_tensor_op.cuh"
#include "blockwise_gemm.cuh"
// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline(
InGlobalDesc,
Float* const __restrict__ p_in_global,
WeiGlobalDesc,
Float* const __restrict__ p_wei_global,
OutGlobalDesc,
Float* __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_cnhw_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_knhw_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_cnhw_global_desc.GetLength(I0);
constexpr unsigned N = in_cnhw_global_desc.GetLength(I1);
constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
constexpr unsigned K = out_knhw_global_desc.GetLength(I0);
constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
// divide the work among blocks in 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
#if 0
if(get_thread_local_1d_id() == 0)
{
printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead);
printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
KBlockWork,
BBlockWork,
k_block_data_begin,
b_block_data_begin);
}
#endif
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
// tensor view of blockwise input and weight
constexpr auto in_cb_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, BPerBlock + BGhostRead>{});
constexpr auto wei_ek_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock * S * R, KPerBlock>{});
constexpr auto wei_csrk_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{});
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
printf("KPerBlock %u\n", KPerBlock);
}
#endif
// blockwise in copy
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<BPerBlock>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
true,
false,
false,
CPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
// LDS
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace();
// LDS double buffer
__shared__ Float p_in_block_0[in_block_size];
__shared__ Float p_wei_block_0[wei_block_size];
__shared__ Float p_in_block_1[in_block_size];
__shared__ Float p_wei_block_1[wei_block_size];
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
Float* p_wei_global_block_offset =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
// prologue: preload data
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0);
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
bool even_loop = true;
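// ping-pong between the two LDS buffers: even iterations compute from buffer 0 while
// filling buffer 1, odd iterations swap; iteration starts at CPerBlock because the
// prologue above already loaded slice 0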
for(unsigned c_block_data_begin = CPerBlock; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
even_loop = !even_loop)
{
__syncthreads();
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0;
Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
// a series of GEMM
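// accumulate S * R GEMMs, one per filter tap; tap (s, r) shifts the input window by s * Wi + r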
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block_now + s * Wi + r,
p_out_thread,
f_accum);
}
}
}
// last computation
{
__syncthreads();
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
// a series of GEMM
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block_now + s * Wi + r,
p_out_thread,
f_accum);
}
}
}
// output: register to global mem,
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned b_thread_data_begin = matrix_c_index.col;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
}
#endif
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
unsigned k_data = k_data_begin + k;
unsigned b_data = b_data_begin + b;
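// unflatten b into (n, h, w) coordinates of the (N, Hi, Wi) input view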
unsigned n_data = b_data / (Hi * Wi);
unsigned itmp = b_data - n_data * (Hi * Wi);
unsigned h_data = itmp / Wi;
unsigned w_data = itmp - h_data * Wi;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
k,
b,
k_data,
n_data,
h_data,
w_data,
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
}
#endif
if(n_data < N && h_data < Ho && w_data < Wo)
{
#if 1
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
#endif
}
}
}
}