Commit e090e72a authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add K0 into KernelArgument

parent 66a297cd
...@@ -77,6 +77,7 @@ struct KernelArgument ...@@ -77,6 +77,7 @@ struct KernelArgument
std::array<index_t, Dim> input_left_pads; std::array<index_t, Dim> input_left_pads;
std::array<index_t, Dim> input_right_pads; std::array<index_t, Dim> input_right_pads;
std::array<index_t, Dim> tildes; std::array<index_t, Dim> tildes;
index_t K0;
KernelArgument(index_t N_, KernelArgument(index_t N_,
index_t K_, index_t K_,
...@@ -88,8 +89,9 @@ struct KernelArgument ...@@ -88,8 +89,9 @@ struct KernelArgument
const std::vector<index_t>& conv_filter_dilations_, const std::vector<index_t>& conv_filter_dilations_,
const std::vector<index_t>& input_left_pads_, const std::vector<index_t>& input_left_pads_,
const std::vector<index_t>& input_right_pads_, const std::vector<index_t>& input_right_pads_,
const std::vector<index_t>& tildes_) const std::vector<index_t>& tildes_,
: N{N_}, K{K_}, C{C_} index_t K0_)
: N{N_}, K{K_}, C{C_}, K0{K0_}
{ {
#if defined(PP_COPY_MEMBER_VALUES) #if defined(PP_COPY_MEMBER_VALUES)
#error "PP_COPY_MEMBER_VALUES macro was already defined" #error "PP_COPY_MEMBER_VALUES macro was already defined"
...@@ -213,8 +215,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -213,8 +215,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvStrideW = karg.conv_filter_strides[0]; const index_t ConvStrideW = karg.conv_filter_strides[0];
const index_t ConvDilationW = karg.conv_filter_dilations[0]; const index_t ConvDilationW = karg.conv_filter_dilations[0];
const auto K0 = karg.K / K1;
const auto in_n_wi_c_grid_desc = const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Wi, karg.C)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Wi, karg.C));
...@@ -225,14 +225,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -225,14 +225,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Wo, karg.K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Wo), make_tuple(make_pass_through_transform(karg.N * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -310,13 +310,13 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -310,13 +310,13 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(make_pass_through_transform(karg.N), make_tuple(make_pass_through_transform(karg.N),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, WTildeSlice)), make_merge_transform(make_tuple(karg.N, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
...@@ -334,7 +334,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -334,7 +334,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_xdot_xtilde_c_grid_desc, wei_k_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_pass_through_transform(karg.C)), make_pass_through_transform(karg.C)),
...@@ -343,7 +343,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -343,7 +343,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc, wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(XDotSlice, karg.K0)),
make_pass_through_transform(karg.C), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
...@@ -418,8 +418,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -418,8 +418,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvDilationH = karg.conv_filter_dilations[0]; const index_t ConvDilationH = karg.conv_filter_dilations[0];
const index_t ConvDilationW = karg.conv_filter_dilations[1]; const index_t ConvDilationW = karg.conv_filter_dilations[1];
const auto K0 = karg.K / K1;
const auto out_n_ho_wo_k_grid_desc = const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Ho, Wo, karg.K)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Ho, Wo, karg.K));
const auto wei_k_y_x_c_grid_desc = const auto wei_k_y_x_c_grid_desc =
...@@ -434,14 +432,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -434,14 +432,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Ho * Wo, karg.K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Ho * Wo), make_tuple(make_pass_through_transform(karg.N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -533,7 +531,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -533,7 +531,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -549,7 +547,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -549,7 +547,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(karg.N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
...@@ -567,30 +565,30 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -567,30 +565,30 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde), make_freeze_transform(i_xtilde),
make_pass_through_transform(karg.C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple(Sequence<0, 1>{}, make_tuple(Sequence<0, 1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<4>{})); Sequence<4>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, karg.K0)),
make_pass_through_transform(karg.C), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
...@@ -688,8 +686,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -688,8 +686,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvDilationH = karg.conv_filter_dilations[1]; const index_t ConvDilationH = karg.conv_filter_dilations[1];
const index_t ConvDilationW = karg.conv_filter_dilations[2]; const index_t ConvDilationW = karg.conv_filter_dilations[2];
const auto K0 = karg.K / K1;
const auto out_n_do_ho_wo_k_grid_desc = const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Do, Ho, Wo, karg.K)); make_naive_tensor_descriptor_packed(make_tuple(karg.N, Do, Ho, Wo, karg.K));
const auto wei_k_z_y_x_c_grid_desc = const auto wei_k_z_y_x_c_grid_desc =
...@@ -704,14 +700,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -704,14 +700,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Do * Ho * Wo, karg.K)), make_naive_tensor_descriptor_packed(make_tuple(karg.N * Do * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Do * Ho * Wo), make_tuple(make_pass_through_transform(karg.N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)), make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)), make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -839,7 +835,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -839,7 +835,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))), make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -860,7 +856,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -860,7 +856,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(karg.N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
...@@ -888,37 +884,39 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -888,37 +884,39 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
Sequence<7>{})); Sequence<7>{}));
const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, transform_tensor_descriptor(
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_slice_transform(ZDot, I0, ZDotSlice), make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_freeze_transform(i_ztilde), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ztilde),
make_freeze_transform(i_xtilde), make_freeze_transform(i_ytilde),
make_pass_through_transform(karg.C)), make_freeze_transform(i_xtilde),
make_tuple(Sequence<0>{}, make_pass_through_transform(karg.C)),
Sequence<1>{}, make_tuple(Sequence<0>{},
Sequence<3>{}, Sequence<1>{},
Sequence<5>{}, Sequence<3>{},
Sequence<2>{}, Sequence<5>{},
Sequence<4>{}, Sequence<2>{},
Sequence<6>{}, Sequence<4>{},
Sequence<7>{}), Sequence<6>{},
make_tuple(Sequence<0, 1>{}, Sequence<7>{}),
Sequence<2>{}, make_tuple(Sequence<0, 1>{},
Sequence<3>{}, Sequence<2>{},
Sequence<4>{}, Sequence<3>{},
Sequence<>{}, Sequence<4>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<5>{})); Sequence<>{},
Sequence<5>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), make_tuple(
make_pass_through_transform(karg.C), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, karg.K0)),
make_pass_through_transform(K1)), make_pass_through_transform(karg.C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...@@ -1000,14 +998,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1000,14 +998,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
detail::KernelArgument<1>(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0})); detail::KernelArgument<1>(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}, 1));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(detail::KernelArgument<2>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(detail::KernelArgument<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0})); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, 1));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
...@@ -1024,7 +1022,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1024,7 +1022,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{0, 0, 0})); {0, 0, 0},
1));
} }
// GridwiseGemm // GridwiseGemm
...@@ -1140,7 +1139,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1140,7 +1139,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_, conv_filter_dilations_,
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_xtilde})); {i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
...@@ -1176,17 +1176,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1176,17 +1176,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(Conv_N_, detail::KernelArgument<NDimSpatial>(
Conv_K_, Conv_N_,
Conv_C_, Conv_K_,
input_spatial_lengths_, Conv_C_,
filter_spatial_lengths_, input_spatial_lengths_,
output_spatial_lengths_, filter_spatial_lengths_,
conv_filter_strides_, output_spatial_lengths_,
conv_filter_dilations_, conv_filter_strides_,
input_left_pads_, conv_filter_dilations_,
input_right_pads_, input_left_pads_,
{i_ytilde, i_xtilde})); input_right_pads_,
{i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
...@@ -1242,7 +1244,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1242,7 +1244,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_, conv_filter_dilations_,
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde})); {i_ztilde, i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs); grid_desc_container_.push_back(descs);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment