Commit 3f81301f authored by Chao Liu's avatar Chao Liu
Browse files

tweaking bwd data

parent f67adee3
......@@ -227,7 +227,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
......@@ -236,11 +236,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
true>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -48,7 +48,10 @@ struct PassThrough
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
template <typename LowerLengths,
typename LeftPads,
typename RightPads,
bool SkipIsValidCheck = false>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
......@@ -89,6 +92,12 @@ struct Pad
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
......@@ -366,7 +375,10 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowerLength, typename UpperLengths, typename Coefficients>
template <index_t LowerLength,
typename UpperLengths,
typename Coefficients,
bool SkipIsValidCheck = false>
struct Embed
{
static constexpr index_t nDimLow = 1;
......@@ -418,6 +430,12 @@ struct Embed
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
index_t ncorner = 1;
......
......@@ -55,7 +55,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
#if 0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
#elif 0
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr index_t BlockSize = 256;
......
......@@ -49,8 +49,7 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0> using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 256;
......@@ -142,27 +141,27 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 1x7 filter, 23x23 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 23;
constexpr index_t WI = 23;
constexpr index_t K = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t K = 1024;
constexpr index_t Y = 7;
constexpr index_t X = 1;
......@@ -172,27 +171,12 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 128;
constexpr index_t K = 1024;
constexpr index_t Y = 3;
constexpr index_t X = 3;
......
......@@ -74,7 +74,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 128;
......@@ -328,35 +328,35 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment