Commit b8385cca authored by Chao Liu's avatar Chao Liu
Browse files

change Trim to Slice

parent 0368045e
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck { namespace ck {
// GemmM = C * Ytilda * Xtilda; // GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda; // GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot; // GemmK = K * Ydot * Xdot;
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -149,9 +149,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -149,9 +149,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<K>{}, PassThrough<K>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -205,9 +205,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -205,9 +205,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<C>{}, PassThrough<C>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
......
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
namespace ck { namespace ck {
// Ytilda*Xtilda number of GEMMs // Ytilda*Xtilda number of GEMMs
// GemmM = C // GemmM = C;
// GemmN = N * Htilda * Wtilda; // GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * slice(Ydot) * slice(Xdot); // GemmK = K * YdotNonZero * XdotNonZero;
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
typename Float, typename Float,
...@@ -184,9 +184,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -184,9 +184,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<K>{}, PassThrough<K>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -233,9 +233,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -233,9 +233,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<C>{}, PassThrough<C>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -265,12 +265,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -265,12 +265,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{}, make_tuple(PassThrough<K>{},
PassThrough<C>{}, PassThrough<C>{},
Trim<Sequence<Ydot, Xdot>, Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>, Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}, Sequence<YdotNonZero, XdotNonZero>>{},
Trim<Sequence<Ytilda, Xtilda>, Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>, Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}), Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -291,9 +291,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -291,9 +291,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<K>{}, PassThrough<K>{},
PassThrough<HtildaTrim>{}, PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{}, PassThrough<WtildaTrim>{},
Trim<Sequence<Ydot, Xdot>, Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>, Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}), Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -320,9 +320,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -320,9 +320,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<C>{}, PassThrough<C>{},
PassThrough<HtildaTrim>{}, PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{}, PassThrough<WtildaTrim>{},
Trim<Sequence<Ytilda, Xtilda>, Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>, Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}), Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
......
...@@ -157,9 +157,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -157,9 +157,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<K>{}, PassThrough<K>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -206,9 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -206,9 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<C>{}, PassThrough<C>{},
PassThrough<Ytilda>{}, PassThrough<Ytilda>{},
PassThrough<Xtilda>{}, PassThrough<Xtilda>{},
Trim<Sequence<Htilda, Wtilda>, Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>, Sequence<HtildaLeft, WtildaLeft>,
Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}), Sequence<HtildaRight, WtildaRight>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple( make_tuple(
...@@ -227,12 +227,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -227,12 +227,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{}, make_tuple(PassThrough<K>{},
PassThrough<C>{}, PassThrough<C>{},
Trim<Sequence<Ydot, Xdot>, Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>, Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}, Sequence<YdotNonZero, XdotNonZero>>{},
Trim<Sequence<Ytilda, Xtilda>, Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>, Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}), Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
...@@ -250,9 +250,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -250,9 +250,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<K>{}, PassThrough<K>{},
PassThrough<HtildaTrim>{}, PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{}, PassThrough<WtildaTrim>{},
Trim<Sequence<Ydot, Xdot>, Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>, Sequence<0, 0>,
Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}), Sequence<YdotNonZero, XdotNonZero>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple( make_tuple(
...@@ -272,9 +272,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -272,9 +272,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<C>{}, PassThrough<C>{},
PassThrough<HtildaTrim>{}, PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{}, PassThrough<WtildaTrim>{},
Trim<Sequence<Ytilda, Xtilda>, Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>, Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}), Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple( make_tuple(
......
...@@ -110,19 +110,31 @@ struct Pad ...@@ -110,19 +110,31 @@ struct Pad
}; };
// LowerLengths: Sequence<...> // LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftTrims, typename RightTrims> // SliceBegins: Sequence<...>
struct Trim // SliceEnds: Sequence<...>
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
struct Slice
{ {
static constexpr index_t nDim = LowerLengths::Size(); static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>; using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>; using UpperIndex = MultiIndex<nDim>;
__host__ __device__ explicit constexpr Trim() __host__ __device__ explicit constexpr Slice()
{ {
static_assert(LowerLengths::GetSize() == nDim && LeftTrims::GetSize() == nDim && static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
RightTrims::GetSize() == nDim, SliceEnds::GetSize() == nDim,
"wrong! # of dimensions not consistent"); "wrong! # of dimensions not consistent");
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
} }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; } __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
...@@ -131,12 +143,12 @@ struct Trim ...@@ -131,12 +143,12 @@ struct Trim
__host__ __device__ static constexpr auto GetUpperLengths() __host__ __device__ static constexpr auto GetUpperLengths()
{ {
return LowerLengths{} - LeftTrims{} - RightTrims{}; return SliceEnds{} - SliceBegins{};
} }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{ {
return idx_up + LeftTrims{}; return idx_up + SliceBegins{};
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
......
...@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// BlockSize = 256, EperBlock = 8, each thread hold 64 data // BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 1
// BlockSize = 64, each thread hold 64 data // BlockSize = 64, each thread hold 64 data
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data // BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -83,7 +83,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -83,7 +83,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data // BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1 #if 0
// BlockSize = 256, GemmKPerBlock = 8 // BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -158,10 +158,10 @@ int main(int argc, char* argv[]) ...@@ -158,10 +158,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 0 #elif 1
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -188,7 +188,7 @@ int main(int argc, char* argv[]) ...@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 1 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
......
...@@ -87,21 +87,6 @@ int main(int argc, char* argv[]) ...@@ -87,21 +87,6 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>; using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
...@@ -296,7 +281,7 @@ int main(int argc, char* argv[]) ...@@ -296,7 +281,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -327,7 +312,7 @@ int main(int argc, char* argv[]) ...@@ -327,7 +312,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 1 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -439,7 +424,7 @@ int main(int argc, char* argv[]) ...@@ -439,7 +424,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -449,7 +434,7 @@ int main(int argc, char* argv[]) ...@@ -449,7 +434,7 @@ int main(int argc, char* argv[])
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment