Commit f67adee3 authored by Chao Liu

tweaking bwd data

parent c6e3d607
......@@ -20,8 +20,8 @@ template <index_t GridSize,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InputLeftPads,
typename InputRightPads,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
......@@ -113,6 +113,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
......@@ -138,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
......@@ -155,28 +156,116 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 0
constexpr index_t HtildaLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr idext_t HtildaRight = math::integer_divide_ceil
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
#if 1
constexpr index_t HtildaLeft =
math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft =
math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
constexpr index_t WtidaTrimLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1 - ConvDilations{}[0] * (Ytilda - 1),
ConvStrides{}[0]) +
1;
constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1),
ConvStrides{}[1]) +
1;
#else
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#endif
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
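As a standalone sketch of what this window computes, take the "1x7 filter, 0x3 pad, 17x17 input" configuration enabled later in this commit. Ytilda = Xtilda = 1 is an assumption (their definition, stride / gcd(stride, dilation), is outside this hunk), and div_floor/div_ceil below just stand in for math::integer_divide_floor/ceil.
    // standalone check, not part of the kernel
    using idx_t = unsigned;                                        // stand-in for ck::index_t
    constexpr idx_t div_floor(idx_t x, idx_t y) { return x / y; }
    constexpr idx_t div_ceil(idx_t x, idx_t y) { return (x + y - 1) / y; }
    constexpr idx_t Wi = 17, X = 7, ConvStrideW = 1, ConvDilationW = 1, LeftPadW = 3, RightPadW = 3;
    constexpr idx_t Wo = (Wi + LeftPadW + RightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; // 17
    constexpr idx_t Xtilda = 1;                                    // assumed: stride 1, dilation 1
    constexpr idx_t Wtilda = Wo + ConvDilationW * (X - Xtilda);    // 23 (dilation / gcd(stride, dilation) = 1 here)
    constexpr idx_t WtildaLeft  = div_floor(LeftPadW, ConvStrideW);                                          // 3
    constexpr idx_t WtildaRight = div_ceil(LeftPadW + Wi - 1 - ConvDilationW * (Xtilda - 1), ConvStrideW) + 1; // 20
    static_assert(WtildaRight - WtildaLeft == 17,
                  "only the 17 Wtilda positions that touch real input survive; the 3+3 pad-only columns are trimmed");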
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
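For orientation, a hedged sketch of how one (gemmk, gemmn) coordinate of this K-by-N operand maps back to the six output dimensions. The div/mod order assumes the first dimension of each Merge varies slowest, and that Trim maps an upper index to that index plus the left trim, which is consistent with the corrected GetUpperLengths further down in this commit.
    struct OutCoord { unsigned n, k, ydot, htilda, xdot, wtilda; };
    constexpr OutCoord decompose(unsigned gemmk, unsigned gemmn,
                                 unsigned Ydot, unsigned Xdot,
                                 unsigned HtildaTrim, unsigned WtildaTrim,
                                 unsigned HtildaLeft, unsigned WtildaLeft)
    {
        // Merge<Sequence<K, Ydot, Xdot>>: k slowest, xdot fastest;
        // Merge<Sequence<N, HtildaTrim, WtildaTrim>>: n slowest, wtilda fastest;
        // the left-trim offsets shift the trimmed window back to Htilda/Wtilda coordinates.
        return OutCoord{gemmn / (HtildaTrim * WtildaTrim),
                        gemmk / (Ydot * Xdot),
                        (gemmk % (Ydot * Xdot)) / Xdot,
                        (gemmn % (HtildaTrim * WtildaTrim)) / WtildaTrim + HtildaLeft,
                        gemmk % Xdot,
                        gemmn % WtildaTrim + WtildaLeft};
    }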
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Trim<Sequence<Htilda, Wtilda>,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
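What the two Embed transforms encode, as a small sketch: with coefficients Sequence<ConvDilationH, ConvStrideH, 0>, each upper pair (ytilda, htilda) addresses padded-input row dilation * ytilda + stride * htilda (the trailing 0 is the constant offset), and likewise for (xtilda, wtilda) in W. The helper name below is illustrative only.
    constexpr unsigned embed_row(unsigned ytilda, unsigned htilda,
                                 unsigned conv_dilation_h, unsigned conv_stride_h)
    {
        // lower index = ConvDilationH * ytilda + ConvStrideH * htilda + 0
        return conv_dilation_h * ytilda + conv_stride_h * htilda;
    }
    // e.g. stride 2, dilation 1: ytilda selects the tap phase, htilda walks the padded input in stride-sized steps
    static_assert(embed_row(1, 5, 1, 2) == 11, "(ytilda=1, htilda=5) reads padded-input row 11");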
#endif
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
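To put a number on what the trim buys on the GEMM side, a sketch using the same "1x7 filter, 0x3 pad, 17x17 input" case as the window sketch above (C = N = 128, Ytilda = Xtilda = 1 taken from that example; names below are illustrative).
    constexpr unsigned C = 128, N = 128, Ytilda = 1, Xtilda = 1;
    constexpr unsigned Htilda = 17, Wtilda = 23, HtildaTrim = 17, WtildaTrim = 17;
    constexpr unsigned GemmM           = C * Ytilda * Xtilda;         // 128 rows either way
    constexpr unsigned GemmN_untrimmed = N * Htilda * Wtilda;         // 128 * 17 * 23 = 50048
    constexpr unsigned GemmN_trimmed   = N * HtildaTrim * WtildaTrim; // 128 * 17 * 17 = 36992
    static_assert(GemmN_untrimmed - GemmN_trimmed == 13056,
                  "trimming drops the roughly 26% of GEMM columns that only ever touch padding");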
#endif
// GEMM
constexpr auto gridwise_gemm =
......
......@@ -122,7 +122,7 @@ struct Trim
__host__ __device__ static constexpr auto GetUpperLengths()
{
return LowerLengths{} - LeftTrims{} + RightTrims{};
return LowerLengths{} - LeftTrims{} - RightTrims{};
}
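A quick sanity check of the corrected length formula, reusing the numbers from the backward-data kernel above: a 23-wide Wtilda trimmed by 3 on the left and 3 on the right should expose 17 positions. The helper below only illustrates the arithmetic, it is not the Trim class itself.
    constexpr unsigned trimmed_length(unsigned lower, unsigned left, unsigned right)
    {
        return lower - left - right; // the old "- LeftTrims + RightTrims" would give 23 here
    }
    static_assert(trimmed_length(23, 3, 3) == 17, "Trim exposes lower - left - right elements");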
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
......
......@@ -49,6 +49,12 @@ struct integer_divide_ceiler
}
};
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
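Usage note, as a sketch: for the non-negative pad/stride values this helper is applied to, plain integer division already truncates toward zero, which coincides with floor (for negative signed inputs it would not). The asserts below assume they sit where integer_divide_floor is visible, e.g. in the same math namespace.
    static_assert(integer_divide_floor(3, 2) == 1, "floor(3 / 2) == 1");
    static_assert(integer_divide_floor(5, 5) == 1, "floor(5 / 5) == 1");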
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
......
......@@ -27,12 +27,16 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
......@@ -51,7 +55,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -153,14 +157,37 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
#if 0 // debug
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
#else
#if 1
constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1 - ConvDilations{}[0] * (Ytilda - 1),
ConvStrides{}[0]) +
1;
constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1),
ConvStrides{}[1]) +
1;
#else
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#endif
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
#endif
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
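A worked sketch of these sizes for the "3x3 filter, 2x2 stride, 35x35 input, 17x17 output" case enabled later in this commit. Ytilda = Xtilda = 2 (stride / gcd(stride, dilation)) and the 128x128 per-block GEMM tile are assumptions, since neither appears in this hunk; the W side mirrors the H side exactly here.
    constexpr unsigned ceil_div(unsigned x, unsigned y) { return (x + y - 1) / y; }
    constexpr unsigned N = 128, C = 128, Hi = 35, Ho = 17;
    constexpr unsigned Ytilda = 2, Xtilda = 2;                           // assumed: stride 2 / gcd(2, 1)
    constexpr unsigned Htilda = Ho + (3 - Ytilda);                       // 18 (Y = 3, dilation 1, gcd = 1)
    constexpr unsigned HtildaLeft  = 0;                                  // zero padding, floor(0 / 2)
    constexpr unsigned HtildaRight = ceil_div(Hi - 1 - (Ytilda - 1), 2) + 1; // 18
    constexpr unsigned HtildaTrim  = HtildaRight - HtildaLeft;           // 18: nothing to trim without padding
    constexpr unsigned GemmM = C * Ytilda * Xtilda;                      // 128 * 2 * 2 = 512
    constexpr unsigned GemmN = N * HtildaTrim * HtildaTrim;              // 128 * 18 * 18 = 41472
    static_assert(ceil_div(GemmM, 128) * ceil_div(GemmN, 128) == 1296, "4 x 324 = 1296 blocks");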
......
......@@ -141,13 +141,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 1
#elif 0
// 1x7 filter, 23x23 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 23;
constexpr index_t WI = 23;
constexpr index_t K = 1024;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
......@@ -174,10 +174,10 @@ int main(int argc, char* argv[])
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
......@@ -186,13 +186,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
......
......@@ -29,7 +29,7 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 1
#if 0
// 1x1
constexpr index_t N = 256;
constexpr index_t C = 1024;
......@@ -44,7 +44,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 1x7
constexpr index_t N = 128;
constexpr index_t C = 1024;
......@@ -59,7 +59,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -72,6 +72,21 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
......@@ -281,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
......
......@@ -14,13 +14,11 @@ cmake
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D DEVICE_BACKEND=NVIDIA \
-D CUDA_COMMON_INCLUDE_DIR="/package/install/cuda/10.1/NVIDIA_CUDA-10.1_Samples/common/inc" \
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
${MY_PROJECT_SOURCE}
#-D BOOST_ROOT="/package/install/boost_1.67.0" \
#-D CMAKE_CUDA_COMPILER="/package/install/cuda_10.0/bin/nvcc" \
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
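# note: the newly added -maxrregcount=128 caps registers per thread at 128 (ptxas
# spills to local memory if a kernel needs more), and -Xptxas -v prints per-kernel
# register/shared-memory usage so the effect of the cap can be checked.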