"...resnet50_tensorflow.git" did not exist on "d7333c866d70d4ecb72194f573e17b96d5a7c7d1"
Commit 157491ab authored by Chao Liu
Browse files

added bwd data v2r1: no need for atomic

parent b7992190
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
namespace ck { namespace ck {
// GemmK = K * Ydot * Xdot;
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
typename Float, typename Float,
...@@ -73,33 +76,41 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -73,33 +76,41 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
"wrong! aligment requirement for vectorized global load of input tensor will " "wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"); "be violated");
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
// simplicity constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
static_assert(ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1,
"wrong! not supported yet");
// TODO: these logic are only for stride = 1, dilation = 1 constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Ydot = Y; constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = 1;
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t Xdot = X; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xtilda = 1; constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Wtilda = Wo + X - 1;
constexpr index_t GemmK = K * Ydot * Xdot; constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
// weight tensor // weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc, wei_k_c_y_x_global_desc,
make_tuple( make_tuple(PassThrough<K>{},
PassThrough<K>{}, PassThrough<C>{},
Pad<Sequence<Y, X>,
Sequence<0, 0>,
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_yp_xp_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{}, PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>, Sequence<1, 1, 0>>{}, // coefficient may be wrong Embed<Sequence<Ydot, Ytilda>,
Embed<Sequence<Xdot, Xtilda>, Sequence<1, 1, 0>>{}), // coefficient may be wrong Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...@@ -110,23 +121,25 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -110,23 +121,25 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor( constexpr auto out_n_k_hop_wop_global_desc =
out_n_k_ho_wo_global_desc, transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple( make_tuple(PassThrough<N>{},
PassThrough<N>{},
PassThrough<K>{}, PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<Y - 1, X - 1>>{}), // coefficient may Pad<Sequence<Ho, Wo>,
// be wrong Sequence<0, 0>,
Sequence<right_pad_ho, right_pad_wo>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc, out_n_k_hop_wop_global_desc,
make_tuple( make_tuple(PassThrough<N>{},
PassThrough<N>{},
PassThrough<K>{}, PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>, Sequence<0, 1, 0>>{}, // coefficient may be wrong Embed<Sequence<Ydot, Htilda>,
Embed<Sequence<Xdot, Wtilda>, Sequence<0, 1, 0>>{}), // coefficient may be wrong Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...@@ -137,14 +150,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -137,14 +150,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor // input tensor
constexpr auto eff_left_pads = LeftPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto eff_right_pads = RightPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(PassThrough<N>{},
PassThrough<C>{}, PassThrough<C>{},
Pad<Sequence<Hi, Wi>, decltype(eff_left_pads), decltype(eff_right_pads)>{}), Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
...@@ -160,7 +170,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -160,7 +170,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}), make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM // GEMM
......
...@@ -83,37 +83,16 @@ struct Pad ...@@ -83,37 +83,16 @@ struct Pad
// Reports whether the upper (padded) index idx_up lands on a real element of the lower tensor.
// NOTE(review): this span is side-by-side diff residue; each line carries the removed text
// (left column) followed by the added text (right column).
//  - added version: for every dimension, requires
//    LeftPads::At(idim) <= idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim),
//    with no special-casing of zero padding.
//  - removed version: applied the lower-bound test only when LeftPads::At(idim) != 0 and the
//    upper-bound test only when RightPads::At(idim) != 0; it also carried a "#if 0"-disabled
//    shortcut that returned true when all left and right pads were zero.
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
{
#if 0
struct lambda_no_pad
{
__host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
};
if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
sequence_all_of(RightPads{}, lambda_no_pad{}))
{
return true;
}
else
#endif
{ {
bool flag = true; bool flag = true;
static_for<0, nDim, 1>{}([&](auto idim) { static_for<0, nDim, 1>{}([&](auto idim) {
// only check if there is left-padding flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
static_if<(LeftPads::At(idim) != 0)>{}( (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
[&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
// only check if there is right-padding
static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
});
}); });
return flag; return flag;
} }
}
}; };
// LowerLengths: Sequence<...> // LowerLengths: Sequence<...>
......
...@@ -46,18 +46,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ...@@ -46,18 +46,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
{ {
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{}; constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{}; constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{}; constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = a_k_m_global_desc.GetLength(I0); constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLength(I1); constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLength(I1); constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment // lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M, constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
......
...@@ -97,12 +97,57 @@ __host__ __device__ constexpr T min(T x, Ts... xs) ...@@ -97,12 +97,57 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
return x < y ? x : y; return x < y ? x : y;
} }
// NOTE(review): the next four lines are side-by-side diff residue; the removed lcm stub
// (left column) is fused with the header of the added two-argument hcf (right column).
// this is WRONG // highest common factor
// TODO: implement least common multiple properly, instead of calling max() template <typename T>
template <class T, class... Ts> __host__ __device__ constexpr T hcf(T x, T y)
__host__ __device__ constexpr T lcm(T x, Ts... xs) {
// base case: hcf(0, y) is y
if(x == 0)
{
return y;
}
// base case: hcf(x, 0) is x
if(y == 0)
{
return x;
}
// base case: equal operands, either one is the answer
if(x == y)
{
return x;
}
// otherwise subtract the smaller operand from the larger one and recurse
// NOTE(review): one recursive call per subtraction, so the depth grows with the ratio of
// the operands, and a negative operand never reaches a base case. Presumably callers only
// pass small non-negative values; confirm.
if(x > y)
{
return hcf(x - y, y);
}
return hcf(x, y - x);
}
// highest common factor of two compile-time constants.
// Only the template arguments X and Y are used; the function parameters are
// unnamed and never read. The answer is encoded in the returned Number<> type.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
{
    return Number<hcf(X, Y)>{};
}
// highest common factor of three or more values, folded from the right:
//     hcf(x, y, z, zs...) == hcf(x, hcf(y, z, zs...))
// Requiring at least three arguments means every recursive step passes on a
// strictly shorter argument list, so the recursion ends at one of the
// two-argument hcf overloads. The previous body forwarded the unchanged
// argument list to hcf, which selected this same overload again and never
// reached a two-argument overload.
template <typename X, typename Y, typename Z, typename... Zs>
__host__ __device__ constexpr auto hcf(X x, Y y, Z z, Zs... zs)
{
    return hcf(x, hcf(y, z, zs...));
}
// least common multiple of two values of the same type.
// x is divided by the highest common factor before the multiplication, so the
// full product x * y is never formed; the intermediate value cannot exceed
// the result and therefore cannot overflow when the lcm itself fits in T.
// When hcf(x, y) is 0 the result is 0 instead of a division by zero.
template <typename T>
__host__ __device__ constexpr T lcm(T x, T y)
{
    const T common_factor = hcf(x, y);

    if(common_factor == 0)
    {
        return static_cast<T>(0);
    }

    // common_factor divides x exactly, so this equals (x * y) / common_factor
    return (x / common_factor) * y;
}
// least common multiple of a list of values, folded from the right:
// lcm(x, y, zs...) is evaluated as lcm(x, lcm(y, zs...)).
// NOTE(review): the body lines below are side-by-side diff residue; the removed stub
// (left column) returned max(x, xs...), the added code (right column) performs the fold.
template <typename X, typename Y, typename... Zs>
__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
{ {
return max(x, xs...); return lcm(x, lcm(y, zs...));
} }
template <class T> template <class T>
......
...@@ -36,6 +36,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -36,6 +36,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T); std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
...@@ -105,13 +111,20 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -105,13 +111,20 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for // TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
// simplicity // simplicity
constexpr index_t Ydot = 1; constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t Ytilda = Y; constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Xdot = 1; constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Xtilda = X; constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t Wtilda = Wo + X - 1;
constexpr index_t GemmK = K * Ydot * Xdot; constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmM = C * Ytilda * Xtilda;
......
...@@ -22,20 +22,20 @@ int main(int argc, char* argv[]) ...@@ -22,20 +22,20 @@ int main(int argc, char* argv[])
using namespace ck; using namespace ck;
#if 0 #if 0
constexpr index_t N = 4; constexpr index_t N = 8;
constexpr index_t C = 8; constexpr index_t C = 128;
constexpr index_t HI = 11; constexpr index_t HI = 16;
constexpr index_t WI = 11; constexpr index_t WI = 16;
constexpr index_t K = 8; constexpr index_t K = 8;
constexpr index_t Y = 4; constexpr index_t Y = 2;
constexpr index_t X = 4; constexpr index_t X = 2;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<4, 4>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<2, 2>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -52,7 +52,6 @@ int main(int argc, char* argv[]) ...@@ -52,7 +52,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 1536; constexpr index_t C = 1536;
constexpr index_t HI = 8; constexpr index_t HI = 8;
...@@ -68,7 +67,6 @@ int main(int argc, char* argv[]) ...@@ -68,7 +67,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 2048; constexpr index_t C = 2048;
constexpr index_t HI = 8; constexpr index_t HI = 8;
...@@ -84,7 +82,6 @@ int main(int argc, char* argv[]) ...@@ -84,7 +82,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
// cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 832; constexpr index_t C = 832;
constexpr index_t HI = 7; constexpr index_t HI = 7;
...@@ -100,7 +97,6 @@ int main(int argc, char* argv[]) ...@@ -100,7 +97,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1280; constexpr index_t C = 1280;
constexpr index_t HI = 8; constexpr index_t HI = 8;
...@@ -116,7 +112,6 @@ int main(int argc, char* argv[]) ...@@ -116,7 +112,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
// cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
constexpr index_t HI = 14; constexpr index_t HI = 14;
...@@ -132,7 +127,6 @@ int main(int argc, char* argv[]) ...@@ -132,7 +127,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 1536; constexpr index_t C = 1536;
constexpr index_t HI = 8; constexpr index_t HI = 8;
...@@ -148,7 +142,6 @@ int main(int argc, char* argv[]) ...@@ -148,7 +142,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
// cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 28; constexpr index_t HI = 28;
...@@ -164,7 +157,6 @@ int main(int argc, char* argv[]) ...@@ -164,7 +157,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
// cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 832; constexpr index_t C = 832;
constexpr index_t HI = 7; constexpr index_t HI = 7;
...@@ -180,7 +172,6 @@ int main(int argc, char* argv[]) ...@@ -180,7 +172,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 17x17 input // 1x1 filter, 17x17 input
// cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 768; constexpr index_t C = 768;
constexpr index_t HI = 17; constexpr index_t HI = 17;
...@@ -196,7 +187,6 @@ int main(int argc, char* argv[]) ...@@ -196,7 +187,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 528; constexpr index_t C = 528;
constexpr index_t HI = 14; constexpr index_t HI = 14;
...@@ -212,7 +202,6 @@ int main(int argc, char* argv[]) ...@@ -212,7 +202,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 528; constexpr index_t C = 528;
constexpr index_t HI = 14; constexpr index_t HI = 14;
...@@ -228,7 +217,6 @@ int main(int argc, char* argv[]) ...@@ -228,7 +217,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 832; constexpr index_t C = 832;
constexpr index_t HI = 7; constexpr index_t HI = 7;
...@@ -244,7 +232,6 @@ int main(int argc, char* argv[]) ...@@ -244,7 +232,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 288; constexpr index_t C = 288;
constexpr index_t HI = 35; constexpr index_t HI = 35;
...@@ -340,9 +327,6 @@ int main(int argc, char* argv[]) ...@@ -340,9 +327,6 @@ int main(int argc, char* argv[])
#if 0 #if 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#elif 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else #else
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
......
...@@ -58,7 +58,7 @@ int main(int argc, char* argv[]) ...@@ -58,7 +58,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64; constexpr index_t N = 64;
...@@ -250,7 +250,7 @@ int main(int argc, char* argv[]) ...@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -296,7 +296,7 @@ int main(int argc, char* argv[]) ...@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 1 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment