Commit 039b5e7e authored by Chao Liu

tweaking

parent e402e30b
@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
     constexpr index_t ConvDilationH = ConvDilations{}[0];
     constexpr index_t ConvDilationW = ConvDilations{}[1];

+#if 0 // debug
     // sanity-check for vectorized memory load
     // TODO: this logic may not be correct for bwd-data
     static_assert(
@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
         "wrong! alignment requirement for vectorized global load of input tensor will "
         "be violated");
+#endif

     constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
     constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
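Wrapping the assert in #if 0 sidelines it rather than deleting it, consistent with the TODO that its logic may be wrong for bwd-data. What it enforced, restated as a tiny self-contained check (my paraphrase, not library code):

    // vectorized writes of width W along GemmN1 need a 1x1 filter in that
    // direction or a dilation divisible by W
    constexpr bool write_ok(int X, int ConvDilationW, int W)
    {
        return X == 1 || ConvDilationW % W == 0;
    }

    static_assert(write_ok(1, 1, 4), "1x1 filter: vector width 4 passes");
    static_assert(!write_ok(3, 1, 4), "3x3 filter, dilation 1: would have been rejected");
    // note: the 1x1-weight tuning block added later in this commit sets
    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4, which passes via the X == 1 clause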
@@ -88,30 +90,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

-    constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
-    constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
-
-    constexpr index_t Htilda = Ho + right_pad_ho;
-    constexpr index_t Wtilda = Wo + right_pad_wo;
+    constexpr index_t Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
+    constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);

     // weight tensor
-    constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
-        wei_k_c_y_x_global_desc,
-        make_tuple(PassThrough<K>{},
-                   PassThrough<C>{},
-                   Pad<Sequence<Y, X>,
-                       Sequence<0, 0>,
-                       Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
-
     constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
-        wei_k_c_yp_xp_global_desc,
+        wei_k_c_y_x_global_desc,
         make_tuple(PassThrough<K>{},
                    PassThrough<C>{},
-                   Embed<Sequence<Ydot, Ytilda>,
+                   Embed<Y,
+                         Sequence<Ydot, Ytilda>,
                          Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
-                   Embed<Sequence<Xdot, Xtilda>,
+                   Embed<X,
+                         Sequence<Xdot, Xtilda>,
                          Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
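The inlined Htilda/Wtilda arithmetic is easiest to sanity-check with concrete numbers. A minimal sketch follows; it assumes Ytilda = ConvStrideH / hcf_stride_dilation_h (the usual definition in this kernel family, computed just above the hunk but not shown in this diff), and the sample values (3x3 filter, 2x2 stride, 17x17 output) come from the driver case mentioned later in this commit.

    int main()
    {
        constexpr int Y = 3, ConvStrideH = 2, ConvDilationH = 1, Ho = 17;

        constexpr int hcf_stride_dilation_h = 1;                    // hcf(2, 1)
        constexpr int Ytilda = ConvStrideH / hcf_stride_dilation_h; // 2 (assumed definition)
        constexpr int Ydot   = (Y + Ytilda - 1) / Ytilda;           // integer_divide_ceil(3, 2) = 2

        // the new one-liner folds the old right_pad_ho into Htilda directly
        constexpr int Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);

        static_assert(Htilda == 18, "17 output rows plus 1 row of right pad");
        static_assert(Ydot * Ytilda >= Y, "the Ydot x Ytilda grid must cover the filter");
        return 0;
    }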
@@ -122,42 +113,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
         make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));

-#if 0 // debug
     // output tensor
-    constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
-        out_n_k_ho_wo_global_desc,
-        make_tuple(
-            PassThrough<N>{},
-            PassThrough<K>{},
-            Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
-
     constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
-        out_n_k_hop_wop_global_desc,
+        out_n_k_ho_wo_global_desc,
         make_tuple(PassThrough<N>{},
                    PassThrough<K>{},
-                   Embed<Sequence<Ydot, Htilda>,
+                   Embed<Ho,
+                         Sequence<Ydot, Htilda>,
                          Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
-                   Embed<Sequence<Xdot, Wtilda>,
+                   Embed<Wo,
+                         Sequence<Xdot, Wtilda>,
                          Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-#else
-    // output tensor
-    constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
-        out_n_k_ho_wo_global_desc,
-        make_tuple(PassThrough<N>{},
-                   PassThrough<K>{},
-                   Embed<Sequence<Ydot, Htilda>,
-                         Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
-                         false>{},
-                   Embed<Sequence<Xdot, Wtilda>,
-                         Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
-                         false>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-#endif

     constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
         out_n_k_ydot_htilda_xdot_wtilda_global_desc,
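These output-tensor Embed transforms carry a negative leading coefficient, which is why the removed branch pinned the old IsAlwaysValidMapping flag to false here: some upper-index corners map below zero along Ho/Wo. A small illustration with made-up values, not anything from the commit:

    #include <cstdio>

    int main()
    {
        // idx_low = c0 * i_ydot + c1 * i_htilda + c2, per the comment above Embed
        const int c0 = -1, c1 = 1, c2 = 0; // dilation 1, stride 2 -> -ConvDilationH/hcf = -1
        const int Ydot = 2, Ho = 17;

        // corner (Ydot - 1, 0) of the upper box:
        const int idx_low = c0 * (Ydot - 1) + c1 * 0 + c2;
        std::printf("corner maps to %d; valid Ho range is [0, %d)\n", idx_low, Ho); // -1: invalid
        return 0;
    }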
@@ -178,8 +146,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
         in_n_c_hip_wip_global_desc,
         make_tuple(PassThrough<N>{},
                    PassThrough<C>{},
-                   Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
-                   Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+                   Embed<Hi + InputLeftPads::At(0) + InputRightPads::At(0),
+                         Sequence<Ytilda, Htilda>,
+                         Sequence<ConvDilationH, ConvStrideH, 0>>{},
+                   Embed<Wi + InputLeftPads::At(1) + InputRightPads::At(1),
+                         Sequence<Xtilda, Wtilda>,
+                         Sequence<ConvDilationW, ConvStrideW, 0>>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
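The new leading template argument gives Embed the length of the lower dimension it maps into; for the input tensor that is the padded extent, so the corner check introduced below can verify that the farthest (Ytilda - 1, Htilda - 1) corner still lands inside [0, Hi + pads). A hedged sketch of the bound being checked (my own helper, not library code):

    // largest lower index produced by Embed<..., Sequence<Ytilda, Htilda>,
    // Sequence<ConvDilationH, ConvStrideH, 0>>: evaluate the affine map at the far corner
    constexpr int max_lower_index(int ConvDilationH, int ConvStrideH, int Ytilda, int Htilda)
    {
        return ConvDilationH * (Ytilda - 1) + ConvStrideH * (Htilda - 1);
    }

    // validity requires max_lower_index(...) < Hi + InputLeftPads::At(0) + InputRightPads::At(0);
    // when that fails, TransformedTensorDescriptor inspects the lower lengths (see the
    // hunk further down), presumably to set up a runtime bound check instead.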
......
@@ -320,7 +320,7 @@ struct UnMerge
 // UpperLengths: Sequence<...>
 // Coefficients: Sequence<...>
 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
-template <typename UpperLengths, typename Coefficients, bool IsAlwaysValidMapping = true>
+template <index_t LowerLength, typename UpperLengths, typename Coefficients>
 struct Embed
 {
     static constexpr index_t nDimLow = 1;
@@ -345,8 +345,10 @@ struct Embed
     {
         LowerIndex idx_low(Coefficients{}[nDimUp]);

-        static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
+        for(index_t i = 0; i < nDimUp; ++i)
+        {
+            idx_low(0) += idx_up[i] * Coefficients{}[i];
+        }

         return idx_low;
     }
@@ -358,8 +360,10 @@ struct Embed
     {
         LowerIndex idx_low_diff{0};

-        static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
+        for(index_t i = 0; i < nDimUp; ++i)
+        {
+            idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
+        }

         return idx_low_diff;
     }
@@ -368,7 +372,37 @@ struct Embed
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
-        return IsAlwaysValidMapping;
+        bool flag = true;
+
+        index_t ncorner = 1;
+        for(index_t idim = 0; idim < nDimUp; ++idim)
+        {
+            ncorner *= 2;
+        }
+
+        // loop over each corner of the upper tensor
+        for(index_t icorner = 0; icorner < ncorner; ++icorner)
+        {
+            // generate upper index for each corner
+            auto idx_up = make_zero_array<index_t, nDimUp>();
+
+            index_t itmp = icorner;
+            for(index_t idim = nDimUp - 1; idim >= 0; --idim)
+            {
+                idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
+                itmp /= 2;
+            }
+
+            // calculate lower index
+            auto idx_low = CalculateLowerIndex(idx_up);
+
+            // judge if lower index is valid
+            flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
+        }
+
+        return flag;
     }
 };
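Checking only the 2^nDimUp corners is sufficient because the lower index is an affine function of the upper index, and an affine function over a box attains its minimum and maximum at the corners. A minimal standalone sketch of the same logic (plain C++, hypothetical names, not this library's API):

    #include <cassert>
    #include <initializer_list>

    // idx_low = c0*i0 + c1*i1 + c2, as in the comment above struct Embed
    constexpr int lower(int c0, int c1, int c2, int i0, int i1) { return c0 * i0 + c1 * i1 + c2; }

    int main()
    {
        constexpr int L0 = 2, L1 = 18;         // upper lengths (e.g. Ydot, Htilda)
        constexpr int c0 = -1, c1 = 1, c2 = 0; // output-tensor Embed, stride 2 / dilation 1
        constexpr int lower_length = 17;       // e.g. Ho

        // visit all four corners of the upper box, exactly what the rewritten
        // IsValidUpperIndexAlwaysMappedToValidLowerIndex() does
        bool always_valid = true;
        for(int i0 : {0, L0 - 1})
            for(int i1 : {0, L1 - 1})
            {
                const int v = lower(c0, c1, c2, i0, i1);
                always_valid = always_valid && v >= 0 && v < lower_length;
            }

        // corner (1, 0) maps to -1 and corner (0, 17) maps to 17, so not always valid
        assert(!always_valid);
        return 0;
    }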
......
@@ -498,7 +498,10 @@ struct TransformedTensorDescriptor
             constexpr auto tran = Transforms{}.At(itran);

             // check a transformation if it does not always have a valid mapping
-            if(!tran.IsValidUpperIndexAlwaysMappedToValidLowerIndex())
+            constexpr bool is_valid_up_always_mapped_to_valid_low =
+                decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
+
+            if(!is_valid_up_always_mapped_to_valid_low)
             {
                 constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
                 constexpr auto low_lengths_part =
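Binding the result to a constexpr bool through decltype(tran):: makes a static call on the transform's type rather than a member call through the object, which keeps the check a guaranteed constant expression. A sketch of the pattern with made-up names:

    struct SomeTransform
    {
        static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return false; }
    };

    int main()
    {
        constexpr auto tran = SomeTransform{};

        // static call through the deduced type; usable wherever a constant is required
        constexpr bool ok = decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
        static_assert(!ok, "this transform would need a runtime validity guard");
        return 0;
    }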
......
@@ -81,6 +81,37 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN   = 1;
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // BlockSize = 256, each thread holds 64 data
+    // for 1x1 weight, 8x8 input
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t GemmThreadGemmDataPerReadM = 4;
+    constexpr index_t GemmThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 4>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<1, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif

     constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
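The "64 data" in the new block comment checks out: a 128x128 GemmC tile shared by a 256-thread block is 64 accumulators per thread, and the sub-tile/cluster factors multiply back to the same numbers. A quick arithmetic check (mine, not part of the commit):

    // 4x4 Level0 clusters times 4x4 Level1 clusters = 256 threads = BlockSize
    static_assert(4 * 4 * 4 * 4 == 256, "cluster factorization covers the block");

    // per-thread repeats along M and N, assuming the usual
    // MPerBlock = MRepeat * MPerThreadSubC * MLevel0Cluster * MLevel1Cluster layout
    constexpr int GemmMRepeat = 128 / (4 * 4 * 4); // 2
    constexpr int GemmNRepeat = 128 / (4 * 4 * 4); // 2

    // each thread owns an (MRepeat*MPerThreadSubC) x (NRepeat*NPerThreadSubC) = 8x8 C tile
    static_assert(GemmMRepeat * 4 * GemmNRepeat * 4 == 64, "64 accumulators per thread");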
......
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 1
+#if 0
     // BlockSize = 256, GemmKPerBlock = 8
     constexpr index_t BlockSize = 256;
@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN   = 1;
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
+#elif 1
     // BlockSize = 256, GemmKPerBlock = 8
     // 1x1 filter, 8x8 image
     constexpr index_t BlockSize = 256;
......
@@ -13,20 +13,20 @@
 #include "device_tensor.hpp"
 #include "conv_common.hpp"
 #include "host_conv_bwd_data.hpp"
-#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 1
+#if 0
     constexpr index_t N  = 128;
-    constexpr index_t C  = 256;
-    constexpr index_t HI = 35;
-    constexpr index_t WI = 35;
-    constexpr index_t K  = 384;
+    constexpr index_t C  = 128;
+    constexpr index_t HI = 5;
+    constexpr index_t WI = 5;
+    constexpr index_t K  = 8;
     constexpr index_t Y  = 3;
     constexpr index_t X  = 3;
@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 1x7
     constexpr index_t N = 128;
     constexpr index_t C = 1024;
@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     constexpr index_t N = 128;
     constexpr index_t C = 288;
......
@@ -44,7 +44,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 1x7
     constexpr index_t N = 128;
     constexpr index_t C = 1024;
......