Commit c6e3d607 authored by Chao Liu's avatar Chao Liu
Browse files

tweaking bwd data

parent bfba60cf
...@@ -142,18 +142,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -142,18 +142,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 0
constexpr index_t HtildaLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr idext_t HtildaRight = math::integer_divide_ceil
constexpr index_t WtidaTrimLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(PassThrough<N>{},
PassThrough<C>{}, PassThrough<C>{},
Embed<Hi + InputLeftPads::At(0) + InputRightPads::At(0), Trim<Sequence<Htilda, Wtilda>,
Sequence<Ytilda, Htilda>, Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>>{}, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InputLeftPads::At(1) + InputRightPads::At(1),
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#endif
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
......
...@@ -100,6 +100,52 @@ struct Pad ...@@ -100,6 +100,52 @@ struct Pad
} }
}; };
// Trim: index transform that cuts LeftTrims[i] elements off the front and
// RightTrims[i] elements off the back of every lower dimension i.
// The number of dimensions is unchanged (upper and lower both have nDim).
//
// LowerLengths: Sequence<...>, lengths of the lower (untrimmed) dimensions
// LeftTrims:    Sequence<...>, per-dimension number of elements removed at the front
// RightTrims:   Sequence<...>, per-dimension number of elements removed at the back
template <typename LowerLengths, typename LeftTrims, typename RightTrims>
struct Trim
{
    static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ explicit constexpr Trim()
    {
        // All three sequences must describe the same number of dimensions.
        static_assert(LowerLengths::GetSize() == nDim && LeftTrims::GetSize() == nDim &&
                          RightTrims::GetSize() == nDim,
                      "wrong! # of dimensions not consistent");
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

    // Upper length per dimension = lower length minus what is trimmed on both sides.
    // (The right trim must be subtracted: adding it would make the upper range
    // reach past the end of the lower dimension, so a valid upper index would no
    // longer be guaranteed to map to a valid lower index.)
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return LowerLengths{} - LeftTrims{} - RightTrims{};
    }

    // Upper index 0 corresponds to lower index LeftTrims, so shift by the left trim.
    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up + LeftTrims{};
    }

    // The mapping is a constant shift, so an index difference passes through
    // unchanged; the old upper/lower indices are not needed.
    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    // With upper lengths of LowerLengths - LeftTrims - RightTrims, every upper
    // index in range lands in [LeftTrims, LowerLengths - RightTrims) of the lower
    // dimension (assuming non-negative trims), so no validity check is required.
    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};
// LowerLengths: Sequence<...> // LowerLengths: Sequence<...>
template <typename LowerLengths> template <typename LowerLengths>
struct Merge struct Merge
......
...@@ -36,7 +36,7 @@ int main(int argc, char* argv[]) ...@@ -36,7 +36,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -141,6 +141,21 @@ int main(int argc, char* argv[]) ...@@ -141,6 +141,21 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 1
// 1x7 filter, 23x23 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 23;
constexpr index_t WI = 23;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 7x1 filter, 3x0 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -156,7 +171,7 @@ int main(int argc, char* argv[]) ...@@ -156,7 +171,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 0 #elif 1
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment