Commit 157491ab authored by Chao Liu's avatar Chao Liu
Browse files

added bwd data v2r1: no need for atomic

parent b7992190
......@@ -8,6 +8,9 @@
namespace ck {
// GemmK = K * Ydot * Xdot;
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
template <index_t GridSize,
index_t BlockSize,
typename Float,
......@@ -73,33 +76,41 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// TODO: this algorithm supports any stride and dilation, but for now they are fixed to 1 for
// simplicity
static_assert(ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1,
"wrong! not supported yet");
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
// TODO: these logic are only for stride = 1, dilation = 1
constexpr index_t Ydot = Y;
constexpr index_t Ytilda = 1;
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Xdot = X;
constexpr index_t Xtilda = 1;
constexpr index_t Wtilda = Wo + X - 1;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>, Sequence<1, 1, 0>>{}, // coefficient may be wrong
Embed<Sequence<Xdot, Xtilda>, Sequence<1, 1, 0>>{}), // coefficient may be wrong
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Pad<Sequence<Y, X>,
Sequence<0, 0>,
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_yp_xp_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -110,23 +121,25 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<Y - 1, X - 1>>{}), // coefficient may
// be wrong
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_hop_wop_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>,
Sequence<0, 0>,
Sequence<right_pad_ho, right_pad_wo>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>, Sequence<0, 1, 0>>{}, // coefficient may be wrong
Embed<Sequence<Xdot, Wtilda>, Sequence<0, 1, 0>>{}), // coefficient may be wrong
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -137,14 +150,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
constexpr auto eff_left_pads = LeftPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto eff_right_pads = RightPads{} + Sequence<Y - 1, X - 1>{};
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, decltype(eff_left_pads), decltype(eff_right_pads)>{}),
Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
......@@ -160,7 +170,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
......
......@@ -84,35 +84,14 @@ struct Pad
// Returns true only if, in every dimension, the upper (padded) index lies in
// [LeftPads, LeftPads + LowerLengths), i.e. inside the range occupied by the
// un-padded lower tensor rather than in the left or right padding.
//
// The check is applied unconditionally to all nDim dimensions. The earlier
// per-dimension "only check when padding != 0" loop was redundant with this
// check (it could never turn a true result into false), and the disabled
// "#if 0" fast path plus the trailing unreachable return have been removed.
__host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
{
    bool flag = true;

    static_for<0, nDim, 1>{}([&](auto idim) {
        // lower bound excludes the left padding, upper bound excludes the right padding
        flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
               (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
    });

    return flag;
}
};
......
......@@ -46,18 +46,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = a_k_m_global_desc.GetLength(I0);
constexpr auto M = a_k_m_global_desc.GetLength(I1);
constexpr auto N = b_k_n_global_desc.GetLength(I1);
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
......
......@@ -97,12 +97,57 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
return x < y ? x : y;
}
// this is WRONG
// TODO: implement least common multiple properly, instead of calling max()
template <class T, class... Ts>
__host__ __device__ constexpr T lcm(T x, Ts... xs)
// highest common factor (greatest common divisor) of two integers.
//
// Uses the remainder form of Euclid's algorithm, so the recursion depth grows
// logarithmically with the operands. The previous subtraction-based form
// recursed max/min times, which exceeds the compiler's constexpr recursion
// limit (and risks stack overflow at run time) for inputs such as hcf(1, 100000).
//
// hcf(x, 0) == x and hcf(0, y) == y, hence hcf(0, 0) == 0.
// Intended for non-negative operands; with negative operands the sign of the
// result follows the built-in % operator.
template <typename T>
__host__ __device__ constexpr T hcf(T x, T y)
{
    return y == 0 ? x : hcf(y, x % y);
}
// Compile-time overload: the highest common factor of two Number<> constants,
// returned as a Number<> so the value remains part of the type.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
{
    return Number<hcf(X, Y)>{};
}
// highest common factor of three or more values, folded from the left:
// hcf(a, b, c, d) == hcf(hcf(hcf(a, b), c), d).
//
// The previous body forwarded exactly its own argument list to hcf(), so the
// call could only resolve back to this same overload and never reduced the
// problem. Each step now delegates to a two-argument hcf overload, and the
// two-argument case itself is left entirely to those overloads.
template <typename X, typename Y, typename Z, typename... Ws>
__host__ __device__ constexpr auto hcf(X x, Y y, Z z, Ws... ws)
{
    return hcf(hcf(x, y), z, ws...);
}
// least common multiple of two integers.
//
// Divides by the highest common factor before multiplying: x / hcf(x, y) is
// always exact, and the intermediate value never exceeds the final result,
// whereas (x * y) / hcf(x, y) can overflow T even when the result fits.
//
// NOTE: lcm(0, 0) divides by zero, as hcf(0, 0) is 0.
template <typename T>
__host__ __device__ constexpr T lcm(T x, T y)
{
    return (x / hcf(x, y)) * y;
}
// least common multiple of several values, folded from the right:
// lcm(a, b, c) == lcm(a, lcm(b, c)).
//
// The stale "return max(x, xs...);" statement was removed: it referred to a
// parameter pack (xs) that this function does not declare, and it made the
// real computation below unreachable.
template <typename X, typename Y, typename... Zs>
__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
{
    return lcm(x, lcm(y, zs...));
}
template <class T>
......
......@@ -36,6 +36,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
......@@ -105,13 +111,20 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
// TODO: this algorithm supports any stride and dilation, but for now they are fixed to 1 for
// simplicity
constexpr index_t Ydot = 1;
constexpr index_t Ytilda = Y;
constexpr index_t Htilda = Ho + Y - 1;
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Xdot = 1;
constexpr index_t Xtilda = X;
constexpr index_t Wtilda = Wo + X - 1;
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t GemmK = K * Ydot * Xdot;
constexpr index_t GemmM = C * Ytilda * Xtilda;
......
......@@ -22,20 +22,20 @@ int main(int argc, char* argv[])
using namespace ck;
#if 0
constexpr index_t N = 4;
constexpr index_t C = 8;
constexpr index_t HI = 11;
constexpr index_t WI = 11;
constexpr index_t N = 8;
constexpr index_t C = 128;
constexpr index_t HI = 16;
constexpr index_t WI = 16;
constexpr index_t K = 8;
constexpr index_t Y = 4;
constexpr index_t X = 4;
constexpr index_t Y = 2;
constexpr index_t X = 2;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using ConvStrides = Sequence<4, 4>;
using ConvDilations = Sequence<2, 2>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -52,7 +52,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
......@@ -68,7 +67,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 8;
......@@ -84,7 +82,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
......@@ -100,7 +97,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
constexpr index_t N = 128;
constexpr index_t C = 1280;
constexpr index_t HI = 8;
......@@ -116,7 +112,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
......@@ -132,7 +127,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
......@@ -148,7 +142,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 28x28 image
// cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 28;
......@@ -164,7 +157,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
......@@ -180,7 +172,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 17x17 input
// cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
constexpr index_t N = 128;
constexpr index_t C = 768;
constexpr index_t HI = 17;
......@@ -196,7 +187,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
......@@ -212,7 +202,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
......@@ -228,7 +217,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
......@@ -244,7 +232,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
......@@ -340,9 +327,6 @@ int main(int argc, char* argv[])
#if 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#elif 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
......
......@@ -58,7 +58,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
......@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
......@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment