Commit 039b5e7e authored by Chao Liu
Browse files

tweaking

parent e402e30b
......@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
......@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
......@@ -88,30 +90,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
// weight tensor
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Pad<Sequence<Y, X>,
Sequence<0, 0>,
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_yp_xp_global_desc,
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>,
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Xtilda>,
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -122,42 +113,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug
// output tensor
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc,
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#else
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
false>{},
Embed<Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
false>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#endif
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
......@@ -178,8 +146,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + InputLeftPads::At(0) + InputRightPads::At(0),
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InputLeftPads::At(1) + InputRightPads::At(1),
Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -320,7 +320,7 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients, bool IsAlwaysValidMapping = true>
template <index_t LowerLength, typename UpperLengths, typename Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
......@@ -345,8 +345,10 @@ struct Embed
{
LowerIndex idx_low(Coefficients{}[nDimUp]);
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
for(index_t i = 0; i < nDimUp; ++i)
{
idx_low(0) += idx_up[i] * Coefficients{}[i];
}
return idx_low;
}
......@@ -358,8 +360,10 @@ struct Embed
{
LowerIndex idx_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
for(index_t i = 0; i < nDimUp; ++i)
{
idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
}
return idx_low_diff;
}
......@@ -368,7 +372,37 @@ struct Embed
// Reports whether valid upper indices are always mapped to a valid lower index.
// The loop-based check below visits the 2^nDimUp corners of the upper index space
// (every coordinate is either 0 or UpperLengths::At(idim) - 1), maps each corner
// through CalculateLowerIndex, and yields true only if every resulting lower index
// lies in [0, LowerLength).
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
// NOTE(review): as the text stands, this return comes first, so nothing below it is
// reached. It looks like the pre-change line retained by the diff view (the newer
// template header takes LowerLength instead of IsAlwaysValidMapping) -- confirm.
return IsAlwaysValidMapping;
bool flag = true;
// ncorner = 2^nDimUp: two extreme positions per upper dimension
index_t ncorner = 1;
for(index_t idim = 0; idim < nDimUp; ++idim)
{
ncorner *= 2;
}
// loop over each corner of the upper tensor
for(index_t icorner = 0; icorner < ncorner; ++icorner)
{
// generate upper index for each corner
auto idx_up = make_zero_array<index_t, nDimUp>();
// decode icorner bit by bit; the lowest bit selects the position of the last dimension
index_t itmp = icorner;
// NOTE(review): the "idim >= 0" condition only terminates if index_t is a signed
// type -- confirm against the definition of index_t.
for(index_t idim = nDimUp - 1; idim >= 0; --idim)
{
// bit == 0 -> first element of this dimension, bit == 1 -> last element
idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
itmp /= 2;
}
// calculate lower index
auto idx_low = CalculateLowerIndex(idx_up);
// judge if lower index is valid
flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
}
return flag;
}
};
......
......@@ -498,7 +498,10 @@ struct TransformedTensorDescriptor
constexpr auto tran = Transforms{}.At(itran);
// check a transformation if it does not always have a valid mapping
if(!tran.IsValidUpperIndexAlwaysMappedToValidLowerIndex())
constexpr bool is_valid_up_always_mapped_to_valid_low =
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
if(!is_valid_up_always_mapped_to_valid_low)
{
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
......
......@@ -81,6 +81,37 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
......
......@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
#if 0
// BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256;
......@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
#elif 1
// BlockSize = 256, GemmKPerBlock = 8
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
......
......@@ -13,20 +13,20 @@
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
using namespace ck;
#if 1
#if 0
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t C = 128;
constexpr index_t HI = 5;
constexpr index_t WI = 5;
constexpr index_t K = 8;
constexpr index_t Y = 3;
constexpr index_t X = 3;
......@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x7
constexpr index_t N = 128;
constexpr index_t C = 1024;
......@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 288;
......
......@@ -44,7 +44,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x7
constexpr index_t N = 128;
constexpr index_t C = 1024;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment