Commit c0bfcf91 authored by Chao Liu's avatar Chao Liu
Browse files

adding group

parent 19173ab7
......@@ -320,12 +320,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 4,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 4,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
......@@ -432,12 +432,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 5,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 5,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
......@@ -538,6 +538,146 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALay, tensor_layout::convolution::G_N_HW_C>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// This is different
const index_t CStride = a_g_n_c_wis_strides[2];
const index_t WStride = a_g_n_c_wis_strides[2+NDimSpatial];
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(GemmMRaw, GemmKRaw), make_tuple(WStride, CStride);
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(a_g_n_c_wis_srides[1],
a_g_n_c_wis_srides[3],
a_g_n_c_wis_srides[4],
a_g_n_c_wis_srides[2]));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(a_g_n_c_wis_srides[1],
a_g_n_c_wis_srides[3],
a_g_n_c_wis_srides[4],
a_g_n_c_wis_srides[2]));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALay, tensor_layout::convolution::NDHWC>,
......@@ -558,12 +698,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 6,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 6,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
......@@ -711,6 +851,34 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return wei_gemmn_gemmk_grid_desc;
}
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_ZYX_C>,
bool>::type = false>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t GemmNRaw = K;
const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
return wei_gemmn_gemmk_grid_desc;
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::NWK> ||
is_same_v<ELay, tensor_layout::convolution::NHWK> ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment