Commit 3439e4b5 authored by Chao Liu

Padding works (sort of), but the code looks ugly. Tuned some ResNet configs.

parent 8bd6ea1a
...@@ -11,6 +11,7 @@
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh"
...@@ -107,20 +108,31 @@ auto make_TensorDescriptor(TConstTensorDesc)
return TensorDescriptor(lengths, strides);
}
template <class T, class LowerPads, class UpperPads>
void host_direct_convolution(
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
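// note: h_pad_up / w_pad_up are not referenced explicitly below; the bounds check in
// the inner loop skips every out-of-range tap, which covers the upper pads implicitly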
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y)
{
int hi = ho + y - h_pad_low;
for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x)
{
int wi = wo + x - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3])
{
v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x);
}
}
}
}
...@@ -136,10 +148,9 @@ void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr
f_par(std::thread::hardware_concurrency());
}
template <class T, class LowerPads, class UpperPads>
void host_winograd_3x3_convolution(
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
{
constexpr std::size_t OutTileSizeH = 2;
constexpr std::size_t OutTileSizeW = 2;
...@@ -156,6 +167,12 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
std::size_t HO = out.mDesc.GetLengths()[2];
std::size_t WO = out.mDesc.GetLengths()[3];
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t InTileSizeH = OutTileSizeH + S - 1;
std::size_t InTileSizeW = OutTileSizeW + R - 1;
...@@ -171,11 +188,20 @@ void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
auto f_in_hold = [&](auto n, auto c, auto y, auto x) {
for(int j = 0; j < InTileSizeH; ++j)
{
int hi = OutTileSizeH * y + j - h_pad_low;
for(int i = 0; i < InTileSizeW; ++i)
{
int wi = OutTileSizeW * x + i - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3])
{
in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi);
}
else
{
in_hold(n, c, y, x, j, i) = T(0);
}
}
}
};
...@@ -406,7 +432,7 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 7;
constexpr unsigned R = 7;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
...@@ -415,12 +441,63 @@ int main()
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
#elif 1
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
#endif
auto lower_pads = Sequence<HPad, WPad>{};
auto upper_pads = Sequence<HPad, WPad>{};
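// HPad/WPad are symmetric compile-time pads; e.g. with a 3x3 filter and 1x1 pads the
// output spatial size equals the input size (56x56 stays 56x56)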
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads);
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
...@@ -443,6 +520,7 @@ int main()
unsigned nrepeat = 50;
#if 0
#if 0
device_direct_convolution_1
#elif 0
...@@ -451,7 +529,7 @@ int main()
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 0
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw
...@@ -459,15 +537,28 @@ int main()
device_winograd_convolution
#endif
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
#endif
#if 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(in_nchw_desc,
in_nchw,
wei_kcsr_desc,
wei_kcsr,
out_nkhw_desc,
out_nkhw_device,
lower_pads,
upper_pads,
nrepeat);
#endif
#if 1
if(S == 3 && R == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
else
{
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
check_error(out_nkhw_host, out_nkhw_device);
#endif
...
#pragma once
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
void device_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
// reorder weight
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
};
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8;
#elif 0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
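// observed (not asserted) consistency for these configs: the per-thread tiling factors
// the block, (KPerBlock/KPerThread) * (HoPerBlock/HoPerThread) * (WoPerBlock/WoPerThread)
// * (NPerBlock/NPerThread) = 4 * 2 * 4 * 4 = 128 == BlockSize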
#endif
constexpr unsigned GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
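// e.g. for the active 3x3 56x56 config (N=16, K=256, Ho=Wo=56):
// GridSize = (16/16) * (256/64) * (56/2) * (56/4) = 1 * 4 * 28 * 14 = 1568 blocks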
dim3 block_dim(BlockSize);
dim3 grid_dim(GridSize);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
{
cudaEvent_t start, stop;
float elapsedTime;
cudaEventCreate(&start);
cudaEventRecord(start, 0);
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_csrk_desc),
decltype(out_khwn_desc),
LowerPads,
UpperPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread>
<<<grid_dim, block_dim>>>(static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
cudaEventCreate(&stop);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime);
usleep(10000);
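// short sleep between timed launches, presumably to keep back-to-back kernels from
// skewing the per-iteration timing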
}
checkCudaErrors(cudaGetLastError());
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
...@@ -211,6 +211,133 @@ struct blockwise_4d_tensor_copy_1
}
};
template <unsigned BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class DstOpLengths,
class GlobalLowerPads>
struct blockwise_chwn_tensor_copy_with_padding
{
__device__ void run(Float* const __restrict__ p_src,
unsigned c_block_data_begin,
unsigned ho_block_data_begin,
unsigned wo_block_data_begin,
unsigned n_block_data_begin,
Float* __restrict__ p_dst,
unsigned h_block_pad_low,
unsigned w_block_pad_low,
unsigned h_block_pad_up,
unsigned w_block_pad_up) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(DstOpLengths{});
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
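// each of the BlockSize threads copies one element per iteration; the tail pass
// below handles the remainder when the tile size is not a multiple of BlockSize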
Float* const p_src_tmp =
p_src + src_desc.Get1dIndex(c_block_data_begin,
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
n_block_data_begin);
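// p_src_tmp is this block's read origin in global memory: the block offset shifted
// forward by the block-local pad and back by the global pad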
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(src_desc, "src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");
printf("%u %u, \t"
"h_global_pad_low %u w_global_pad_low %u \t"
"h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
"\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_global_pad_low,
w_global_pad_low,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
unsigned did[4];
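// decompose the linear work-item index into a 4-d tensor index via ref_desc strides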
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
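// destination elements that fall in the block's padding halo are zero-filled;
// everything else is copied from global memory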
p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
? Float(0)
: p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])];
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low ||
did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
? Float(0)
: p_src_tmp[src_desc.Get1dIndex(did[0], did[1], did[2], did[3])];
}
}
}
};
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
struct blockwise_4d_tensor_copy_dummy
{
...
...@@ -27,8 +27,45 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
constexpr auto S = wei_desc.GetLength(I2);
constexpr auto R = wei_desc.GetLength(I3);
constexpr auto HO = HI + 1 - S;
constexpr auto WO = WI + 1 - R;
return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}
template <class InDesc, class WeiDesc, class LowerPads, class UpperPads>
__host__ __device__ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
InDesc, WeiDesc, LowerPads, UpperPads)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
"input & weight dimension not consistent");
constexpr auto N = in_desc.GetLength(I0);
constexpr auto HI = in_desc.GetLength(I2);
constexpr auto WI = in_desc.GetLength(I3);
constexpr auto K = wei_desc.GetLength(I0);
constexpr auto S = wei_desc.GetLength(I2);
constexpr auto R = wei_desc.GetLength(I3);
constexpr auto HPadLow = LowerPads{}.Get(I0);
constexpr auto WPadLow = LowerPads{}.Get(I1);
constexpr auto HPadUp = UpperPads{}.Get(I0);
constexpr auto WPadUp = UpperPads{}.Get(I1);
constexpr auto HO = HI + HPadLow + HPadUp + 1 - S;
constexpr auto WO = WI + WPadLow + WPadUp + 1 - R;
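// e.g. HI = 56, HPadLow = HPadUp = 1, S = 3 gives HO = 56 + 1 + 1 + 1 - 3 = 56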
return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}
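// usage sketch (values mirror the padded config in main() above): for a 56x56 input,
// 3x3 filter, and 1x1 pads the returned descriptor has lengths [N, K, 56, 56]:
// constexpr auto out_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
// in_nchw_desc, wei_kcsr_desc, Sequence<1, 1>{}, Sequence<1, 1>{});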
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "threadwise_4d_tensor_op.cuh"
#include "gemm.cuh"
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class LowerPads,
class UpperPads,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_with_padding(
Float* const __restrict__ p_in_global,
Float* const __restrict__ p_wei_global,
Float* __restrict__ p_out_global)
{
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned HPadLow = LowerPads{}.Get(I0);
constexpr unsigned WPadLow = LowerPads{}.Get(I1);
constexpr unsigned HPadUp = UpperPads{}.Get(I0);
constexpr unsigned WPadUp = UpperPads{}.Get(I1);
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
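// decode the 1-d block id into (k, h, w, n) block work coordinates, n fastest-varying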
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
// tensor view of blockwise input and weight in LDS
constexpr auto in_chwn_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
constexpr auto wei_csrk_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{});
// tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
}
#endif
// blockwise copy
// input: format is [C, Hi, Wi, N]
const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
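// only blocks touching the image boundary see non-zero block pads; interior blocks
// copy their input tile with no zero-filling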
#if 0
if(get_thread_local_1d_id() == 0)
{
printf(
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
constexpr auto blockwise_in_copy =
blockwise_chwn_tensor_copy_with_padding<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
LowerPads>{};
// weight: format is [C,S,R,K]
constexpr auto blockwise_wei_copy =
blockwise_4d_tensor_copy_1<BlockSize,
Float,
decltype(wei_csrk_global_desc),
decltype(wei_csrk_block_desc),
decltype(wei_csrk_block_desc.GetLengths())>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,S,R,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
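// i.e. an implicit GEMM with M = K, N = Wo*N, reduction dimension C, batched over Ho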
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<KPerBlock>{},
Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
const auto blockwise_batch_gemm =
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
true,
false,
false,
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(
I0),
HoPerBlock,
HoPerThread,
CPerThread,
true>{};
// LDS
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
// register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
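// the __syncthreads() in the loop increment keeps the next iteration's copies from
// overwriting LDS while the previous iteration's GEMMs may still be reading it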
#if 1
// input: global mem to LDS,
blockwise_in_copy.run(p_in_global,
c_block_data_begin,
ho_block_data_begin,
wo_block_data_begin,
n_block_data_begin,
p_in_block,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
#endif
#if 1
// weight: global mem to LDS,
blockwise_wei_copy.run(p_wei_global + wei_csrk_global_desc.Get1dIndex(
c_block_data_begin, 0, 0, k_block_data_begin),
p_wei_block);
#endif
__syncthreads();
// a series of batched GEMM
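// each (s,r) filter tap contributes one shifted GEMM over the input tile in LDS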
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto& v) { acc += v; };
blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
p_out_thread,
f_accum);
}
}
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
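// matrix C columns run over Wo*N with N fastest, so col_begin splits into (wo, n)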
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
get_block_1d_id(), get_thread_local_1d_id(),
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
p_out_thread[0]);
#endif
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
}