Commit 66a297cd authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Merge kernel arguments into one object

parent cd01db8b
...@@ -62,6 +62,55 @@ __global__ void ...@@ -62,6 +62,55 @@ __global__ void
ignore = NumKBlockLoop; ignore = NumKBlockLoop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <size_t Dim>
struct KernelArgument
{
index_t N;
index_t K;
index_t C;
std::array<index_t, Dim> input_spatial_lengths;
std::array<index_t, Dim> filter_spatial_lengths;
std::array<index_t, Dim> output_spatial_lengths;
std::array<index_t, Dim> conv_filter_strides;
std::array<index_t, Dim> conv_filter_dilations;
std::array<index_t, Dim> input_left_pads;
std::array<index_t, Dim> input_right_pads;
std::array<index_t, Dim> tildes;
KernelArgument(index_t N_,
index_t K_,
index_t C_,
const std::vector<index_t>& input_spatial_lengths_,
const std::vector<index_t>& filter_spatial_lengths_,
const std::vector<index_t>& output_spatial_lengths_,
const std::vector<index_t>& conv_filter_strides_,
const std::vector<index_t>& conv_filter_dilations_,
const std::vector<index_t>& input_left_pads_,
const std::vector<index_t>& input_right_pads_,
const std::vector<index_t>& tildes_)
: N{N_}, K{K_}, C{C_}
{
#if defined(PP_COPY_MEMBER_VALUES)
#error "PP_COPY_MEMBER_VALUES macro was already defined"
#else
#define PP_COPY_MEMBER_VALUES(member) \
std::copy_n(std::begin(member##_), \
std::min(std::size(member##_), std::size(member)), \
std::begin(member))
PP_COPY_MEMBER_VALUES(input_spatial_lengths);
PP_COPY_MEMBER_VALUES(filter_spatial_lengths);
PP_COPY_MEMBER_VALUES(output_spatial_lengths);
PP_COPY_MEMBER_VALUES(conv_filter_strides);
PP_COPY_MEMBER_VALUES(conv_filter_dilations);
PP_COPY_MEMBER_VALUES(input_left_pads);
PP_COPY_MEMBER_VALUES(input_right_pads);
PP_COPY_MEMBER_VALUES(tildes);
#undef PP_COPY_MEMBER_VALUES
#endif
}
};
} // namespace detail } // namespace detail
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
...@@ -150,68 +199,58 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -150,68 +199,58 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static constexpr auto GemmK1Number = K1Number; static constexpr auto GemmK1Number = K1Number;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
index_t i_xtilde = tildes[0]; index_t i_xtilde = karg.tildes[0];
const index_t Wi = input_spatial_lengths[0]; const index_t Wi = karg.input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0]; const index_t Wo = karg.output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0]; const index_t X = karg.filter_spatial_lengths[0];
const index_t InLeftPadW = input_left_pads[0]; const index_t InLeftPadW = karg.input_left_pads[0];
const index_t InRightPadW = input_right_pads[0]; const index_t InRightPadW = karg.input_right_pads[0];
const index_t ConvStrideW = conv_filter_strides[0]; const index_t ConvStrideW = karg.conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0]; const index_t ConvDilationW = karg.conv_filter_dilations[0];
const auto K0 = K / K1; const auto K0 = karg.K / K1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Wi, karg.C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Wo, karg.K)),
make_tuple(make_pass_through_transform(N * Wo), make_tuple(make_pass_through_transform(karg.N * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor // C: input tensor
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc, in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_x_wo_c_grid_desc, in_n_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Wo)), make_merge_transform(make_tuple(karg.N, Wo)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
...@@ -222,9 +261,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -222,9 +261,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
else else
{ {
const auto out_n_wo_k_grid_desc = const auto out_n_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Wo, karg.K));
const auto wei_k_x_c_grid_desc = const auto wei_k_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); make_naive_tensor_descriptor_packed(make_tuple(karg.K, X, karg.C));
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
...@@ -250,25 +289,25 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -250,25 +289,25 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// A: output tensor // A: output tensor
const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
out_n_wo_k_grid_desc, out_n_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_wop_k_grid_desc, out_n_wop_k_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(XDot, WTilde), make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
out_n_xdot_wtilde_k_grid_desc, out_n_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
...@@ -278,7 +317,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -278,7 +317,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_merge_transform(make_tuple(N, WTildeSlice)), make_merge_transform(make_tuple(karg.N, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -286,10 +325,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -286,10 +325,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// B weight tensor // B weight tensor
const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_x_c_grid_desc, wei_k_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(karg.K),
make_embed_transform(make_tuple(XDot, XTilde), make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
...@@ -298,14 +337,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -298,14 +337,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc, wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_pass_through_transform(C), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -313,34 +352,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -313,34 +352,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// C: input tensor // C: input tensor
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc, in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc, in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_xtilde_wtilde_c_grid_desc, in_n_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_wtildeslice_c_grid_desc, in_n_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), make_tuple(make_merge_transform(make_tuple(karg.N, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -351,80 +390,69 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -351,80 +390,69 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
index_t i_ytilde = tildes[0]; index_t i_ytilde = karg.tildes[0];
index_t i_xtilde = tildes[1]; index_t i_xtilde = karg.tildes[1];
const index_t Hi = input_spatial_lengths[0]; const index_t Hi = karg.input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1]; const index_t Wi = karg.input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0]; const index_t Ho = karg.output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1]; const index_t Wo = karg.output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0]; const index_t Y = karg.filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1]; const index_t X = karg.filter_spatial_lengths[1];
const index_t InLeftPadH = input_left_pads[0]; const index_t InLeftPadH = karg.input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1]; const index_t InLeftPadW = karg.input_left_pads[1];
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadH = karg.input_right_pads[0];
const index_t InRightPadW = input_right_pads[1]; const index_t InRightPadW = karg.input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideH = karg.conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1]; const index_t ConvStrideW = karg.conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0]; const index_t ConvDilationH = karg.conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1]; const index_t ConvDilationW = karg.conv_filter_dilations[1];
const auto K0 = K / K1; const auto K0 = karg.K / K1;
const auto out_n_ho_wo_k_grid_desc = const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Ho, Wo, karg.K));
const auto wei_k_y_x_c_grid_desc = const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); make_naive_tensor_descriptor_packed(make_tuple(karg.K, Y, X, karg.C));
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Hi, Wi, karg.C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_tuple(make_pass_through_transform(karg.N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor // C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
...@@ -432,8 +460,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -432,8 +460,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
in_n_y_ho_x_wo_c_grid_desc, in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)), make_merge_transform(make_tuple(karg.N, Ho, Wo)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
...@@ -478,29 +506,29 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -478,29 +506,29 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// A: output tensor // A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc, out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(YDot, HTilde), make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde), make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
...@@ -522,7 +550,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -522,7 +550,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(karg.N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -530,12 +558,12 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -530,12 +558,12 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// B weight tensor // B weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc, wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(karg.K),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde), make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
...@@ -546,7 +574,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -546,7 +574,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -563,7 +591,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -563,7 +591,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -571,32 +599,32 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -571,32 +599,32 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// C: input tensor // C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -612,8 +640,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -612,8 +640,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc, in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_tuple(make_merge_transform(make_tuple(karg.N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -624,89 +652,78 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -624,89 +652,78 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(detail::KernelArgument<NDim> karg)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
const index_t i_ztilde = tildes[0]; const index_t i_ztilde = karg.tildes[0];
const index_t i_ytilde = tildes[1]; const index_t i_ytilde = karg.tildes[1];
const index_t i_xtilde = tildes[2]; const index_t i_xtilde = karg.tildes[2];
const index_t Di = input_spatial_lengths[0]; const index_t Di = karg.input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1]; const index_t Hi = karg.input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2]; const index_t Wi = karg.input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0]; const index_t Do = karg.output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1]; const index_t Ho = karg.output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2]; const index_t Wo = karg.output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0]; const index_t Z = karg.filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1]; const index_t Y = karg.filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2]; const index_t X = karg.filter_spatial_lengths[2];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadD = karg.input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadH = karg.input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2]; const index_t InLeftPadW = karg.input_left_pads[2];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadD = karg.input_right_pads[0];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = karg.input_right_pads[1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = karg.input_right_pads[2];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideD = karg.conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideH = karg.conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2]; const index_t ConvStrideW = karg.conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationD = karg.conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1]; const index_t ConvDilationH = karg.conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2]; const index_t ConvDilationW = karg.conv_filter_dilations[2];
const auto K0 = K / K1; const auto K0 = karg.K / K1;
const auto out_n_do_ho_wo_k_grid_desc = const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Do, Ho, Wo, karg.K));
const auto wei_k_z_y_x_c_grid_desc = const auto wei_k_z_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); make_naive_tensor_descriptor_packed(make_tuple(karg.K, Z, Y, X, karg.C));
const auto in_n_di_hi_wi_c_grid_desc = const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Di, Hi, Wi, karg.C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Do * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_tuple(make_pass_through_transform(karg.N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor // C: input tensor
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc, in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -720,8 +737,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -720,8 +737,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_merge_transform(make_tuple(karg.N, Do, Ho, Wo)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<1>{}, make_tuple(Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<5>{}, Sequence<5>{},
...@@ -781,11 +798,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -781,11 +798,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// A: output tensor // A: output tensor
const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_k_grid_desc, out_n_do_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Do, I0, I0), make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -795,14 +812,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -795,14 +812,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_dop_hop_wop_k_grid_desc, out_n_dop_hop_wop_k_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(ZDot, DTilde), make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde), make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde), make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)), make_pass_through_transform(karg.K)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -815,7 +832,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -815,7 +832,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_slice_transform(ZDot, I0, ZDotSlice), make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
...@@ -844,7 +861,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -844,7 +861,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(karg.N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -854,14 +871,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -854,14 +871,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
transform_tensor_descriptor( transform_tensor_descriptor(
wei_k_z_y_x_c_grid_desc, wei_k_z_y_x_c_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(K), make_pass_through_transform(karg.K),
make_embed_transform(make_tuple(ZDot, ZTilde), make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)), make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde), make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -879,7 +896,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -879,7 +896,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_freeze_transform(i_ztilde), make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -900,7 +917,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -900,7 +917,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -908,11 +925,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -908,11 +925,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
// C: input tensor // C: input tensor
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc, in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -921,14 +938,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -921,14 +938,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc, in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_embed_transform(make_tuple(ZTilde, DTilde), make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -940,14 +957,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -940,14 +957,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(karg.N),
make_freeze_transform(i_ztilde), make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -968,8 +985,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -968,8 +985,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(karg.N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -983,30 +1000,31 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -983,30 +1000,31 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); detail::KernelArgument<1>(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(detail::KernelArgument<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
1, detail::KernelArgument<3>(1,
1, 1,
{1, 1, 1}, 1,
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{0, 0, 0}); {1, 1, 1},
{0, 0, 0}));
} }
// GridwiseGemm // GridwiseGemm
...@@ -1112,20 +1130,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1112,20 +1130,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
input_spatial_lengths_, input_spatial_lengths_,
filter_spatial_lengths_, filter_spatial_lengths_,
output_spatial_lengths_, output_spatial_lengths_,
conv_filter_strides_, conv_filter_strides_,
conv_filter_dilations_, conv_filter_dilations_,
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_xtilde}); {i_xtilde}));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
void CreateABCDesc() void CreateABCDesc()
{ {
...@@ -1157,21 +1176,22 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1157,21 +1176,22 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, detail::KernelArgument<NDimSpatial>(Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
input_spatial_lengths_, input_spatial_lengths_,
filter_spatial_lengths_, filter_spatial_lengths_,
output_spatial_lengths_, output_spatial_lengths_,
conv_filter_strides_, conv_filter_strides_,
conv_filter_dilations_, conv_filter_dilations_,
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_ytilde, i_xtilde}); {i_ytilde, i_xtilde}));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
void CreateABCDesc() void CreateABCDesc()
{ {
...@@ -1211,17 +1231,18 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1211,17 +1231,18 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, detail::KernelArgument<NDimSpatial>(
Conv_K_, Conv_N_,
Conv_C_, Conv_K_,
input_spatial_lengths_, Conv_C_,
filter_spatial_lengths_, input_spatial_lengths_,
output_spatial_lengths_, filter_spatial_lengths_,
conv_filter_strides_, output_spatial_lengths_,
conv_filter_dilations_, conv_filter_strides_,
input_left_pads_, conv_filter_dilations_,
input_right_pads_, input_left_pads_,
{i_ztilde, i_ytilde, i_xtilde}); input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde}));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment