Commit 6b165b9b authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 7f9e59a1
...@@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1]; constexpr index_t ConvDilationW = ConvDilations{}[1];
//\todo static_assert for global vector load/store
// statc_assert();
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
...@@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
constexpr index_t HTilda = constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = constexpr index_t WTilda =
...@@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// A matrix: weight
// weight out-of-bound check can be skipped // weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true; constexpr bool wei_skip_out_of_bound_check = true;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc, wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{}, make_tuple(PassThrough<K>{},
...@@ -217,15 +217,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -217,15 +217,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK #if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false; constexpr bool out_skip_out_of_bound_check = false;
#else #else
//\todo sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
constexpr bool out_skip_out_of_bound_check = true; constexpr bool out_skip_out_of_bound_check = true;
#endif #endif
// output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc, out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(PassThrough<N>{},
...@@ -256,14 +275,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -256,14 +275,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix: input tensor
// TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK #if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool in_skip_out_of_bound_check = false; constexpr bool in_skip_out_of_bound_check = false;
#else #else
//\todo sometimes input out-of-bound check can be skipped, find out all such situations constexpr bool in_skip_out_of_bound_check = true;
constexpr bool in_skip_out_of_bound_check = true;
#endif #endif
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple( make_tuple(
...@@ -306,53 +346,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -306,53 +346,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMM
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
// A matrix
constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
......
...@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 1
// cdata = 64, BlockSize = 256, 128x128x16 // cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -190,7 +190,7 @@ int main(int argc, char* argv[]) ...@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 1024; constexpr index_t K = 1024;
...@@ -247,7 +247,7 @@ int main(int argc, char* argv[]) ...@@ -247,7 +247,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 1 #elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment