Commit b8385cca authored by Chao Liu
Browse files

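Change Trim to Slice: replace the Trim transform (parameterized by left/right trim amounts) with a Slice transform (parameterized by begin/end indices)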

parent 0368045e
......@@ -9,7 +9,7 @@
namespace ck {
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot;
template <index_t GridSize,
index_t BlockSize,
......@@ -149,9 +149,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -205,9 +205,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......
......@@ -9,9 +9,9 @@
namespace ck {
// Ytilda*Xtilda number of GEMMs
// GemmM = C
// GemmN = N * Htilda * Wtilda;
// GemmK = K * slice(Ydot) * slice(Xdot);
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
template <index_t GridSize,
index_t BlockSize,
typename Float,
......@@ -184,9 +184,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -233,9 +233,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -265,12 +265,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Trim<Sequence<Ydot, Xdot>,
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{},
Trim<Sequence<Ytilda, Xtilda>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -291,9 +291,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ydot, Xdot>,
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}),
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -320,9 +320,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ytilda, Xtilda>,
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......
......@@ -157,9 +157,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -206,9 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>,
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
Sequence<HtildaRight, WtildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -227,12 +227,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Trim<Sequence<Ydot, Xdot>,
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{},
Trim<Sequence<Ytilda, Xtilda>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
......@@ -250,9 +250,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ydot, Xdot>,
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}),
Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
......@@ -272,9 +272,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Trim<Sequence<Ytilda, Xtilda>,
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
......
......@@ -110,19 +110,31 @@ struct Pad
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftTrims, typename RightTrims>
struct Trim
// SliceBegins: Sequence<...>
// SliceEnds: Sequence<...>
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
struct Slice
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ explicit constexpr Trim()
__host__ __device__ explicit constexpr Slice()
{
static_assert(LowerLengths::GetSize() == nDim && LeftTrims::GetSize() == nDim &&
RightTrims::GetSize() == nDim,
static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
SliceEnds::GetSize() == nDim,
"wrong! # of dimensions not consistent");
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
......@@ -131,12 +143,12 @@ struct Trim
__host__ __device__ static constexpr auto GetUpperLengths()
{
return LowerLengths{} - LeftTrims{} - RightTrims{};
return SliceEnds{} - SliceBegins{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up + LeftTrims{};
return idx_up + SliceBegins{};
}
__host__ __device__ static constexpr auto
......
......@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
#if 1
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
#elif 1
// BlockSize = 64, each thread hold 64 data
constexpr index_t BlockSize = 64;
......
......@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
#if 0
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -83,7 +83,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
#elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr index_t BlockSize = 256;
......
......@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
#if 0
// BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256;
......
......@@ -158,10 +158,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
......@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 1024;
......
......@@ -87,21 +87,6 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
......@@ -296,7 +281,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
......@@ -327,7 +312,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......@@ -439,7 +424,7 @@ int main(int argc, char* argv[])
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1
#elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......@@ -449,7 +434,7 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 1
#elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment