Commit e090e72a authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add K0 into KernelArgument

parent 66a297cd
......@@ -77,6 +77,7 @@ struct KernelArgument
std::array<index_t, Dim> input_left_pads;
std::array<index_t, Dim> input_right_pads;
std::array<index_t, Dim> tildes;
index_t K0;
KernelArgument(index_t N_,
index_t K_,
......@@ -88,8 +89,9 @@ struct KernelArgument
const std::vector<index_t>& conv_filter_dilations_,
const std::vector<index_t>& input_left_pads_,
const std::vector<index_t>& input_right_pads_,
const std::vector<index_t>& tildes_)
: N{N_}, K{K_}, C{C_}
const std::vector<index_t>& tildes_,
index_t K0_)
: N{N_}, K{K_}, C{C_}, K0{K0_}
{
#if defined(PP_COPY_MEMBER_VALUES)
#error "PP_COPY_MEMBER_VALUES macro was already defined"
......@@ -213,8 +215,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvStrideW = karg.conv_filter_strides[0];
const index_t ConvDilationW = karg.conv_filter_dilations[0];
const auto K0 = karg.K / K1;
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Wi, karg.C));
......@@ -225,14 +225,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -310,13 +310,13 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(make_pass_through_transform(karg.N),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_tuple(make_merge_transform(make_tuple(XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
......@@ -334,7 +334,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde),
make_pass_through_transform(karg.C)),
......@@ -343,7 +343,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_tuple(make_merge_transform(make_tuple(XDotSlice, karg.K0)),
make_pass_through_transform(karg.C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
......@@ -418,8 +418,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvDilationH = karg.conv_filter_dilations[0];
const index_t ConvDilationW = karg.conv_filter_dilations[1];
const auto K0 = karg.K / K1;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Ho, Wo, karg.K));
const auto wei_k_y_x_c_grid_desc =
......@@ -434,14 +432,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -533,7 +531,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -549,7 +547,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
......@@ -567,9 +565,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
......@@ -590,7 +588,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, karg.K0)),
make_pass_through_transform(karg.C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
......@@ -688,8 +686,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const index_t ConvDilationH = karg.conv_filter_dilations[1];
const index_t ConvDilationW = karg.conv_filter_dilations[2];
const auto K0 = karg.K / K1;
const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(karg.N, Do, Ho, Wo, karg.K));
const auto wei_k_z_y_x_c_grid_desc =
......@@ -704,14 +700,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.N * Do * Ho * Wo, karg.K)),
make_tuple(make_pass_through_transform(karg.N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(karg.K, karg.C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_pass_through_transform(karg.C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -839,7 +835,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_unmerge_transform(make_tuple(karg.K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -860,7 +856,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, karg.K0)),
make_merge_transform(make_tuple(karg.N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
......@@ -888,8 +884,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
Sequence<7>{}));
const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(karg.K0, K1)),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
......@@ -916,7 +913,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, karg.K0)),
make_pass_through_transform(karg.C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
......@@ -1000,14 +998,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static auto GetDummyABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
detail::KernelArgument<1>(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}));
detail::KernelArgument<1>(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}, 1));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetDummyABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(detail::KernelArgument<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}));
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, 1));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
......@@ -1024,7 +1022,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{0, 0, 0}));
{0, 0, 0},
1));
}
// GridwiseGemm
......@@ -1140,7 +1139,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde}));
{i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
}
}
......@@ -1176,7 +1176,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
detail::KernelArgument<NDimSpatial>(Conv_N_,
detail::KernelArgument<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
......@@ -1186,7 +1187,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde}));
{i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
}
}
......@@ -1242,7 +1244,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde}));
{i_ztilde, i_ytilde, i_xtilde},
GridwiseGemm::CalculateK0(Conv_K_)));
grid_desc_container_.push_back(descs);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment