Commit 85bddfbc authored by Bartlomiej Kocot, committed by Bartłomiej Kocot
Browse files

Add wei_strides to grouped conv3d wei to keep consistency

parent 7761e523
...@@ -109,6 +109,7 @@ bool run_grouped_conv_bwd_weight( ...@@ -109,6 +109,7 @@ bool run_grouped_conv_bwd_weight(
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides, const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides, const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
...@@ -160,6 +161,7 @@ bool run_grouped_conv_bwd_weight( ...@@ -160,6 +161,7 @@ bool run_grouped_conv_bwd_weight(
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -26,6 +26,7 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi ...@@ -26,6 +26,7 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X}; static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
// Strides of the weight tensor, declared next to input_strides and output_strides.
// NOTE(review): the values read like packed strides for a G, K, X, C ordering
// (K * X * C, X * C, C, 1) — confirm this matches the dimension order the callee expects.
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
...@@ -48,6 +49,7 @@ int main() ...@@ -48,6 +49,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -30,6 +30,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y ...@@ -30,6 +30,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1}; N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
// Strides of the weight tensor, declared next to input_strides and output_strides.
// NOTE(review): the values read like packed strides for a G, K, Y, X, C ordering
// — confirm this matches the dimension order the callee expects.
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Y * X * C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1}; N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
...@@ -53,6 +55,7 @@ int main() ...@@ -53,6 +55,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z ...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
// Strides of the weight tensor, declared next to input_strides and output_strides.
// NOTE(review): the values read like packed strides for a G, K, Z, Y, X, C ordering
// — confirm this matches the dimension order the callee expects.
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
...@@ -56,6 +58,7 @@ int main() ...@@ -56,6 +58,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z ...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
// Strides of the weight tensor, declared next to input_strides and output_strides.
// NOTE(review): the values read like packed strides for a G, K, Z, Y, X, C ordering
// — confirm this matches the dimension order the callee expects.
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
...@@ -48,20 +50,20 @@ int main() ...@@ -48,20 +50,20 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(G,
G, N,
N, K,
K, C,
C, input_spatial_lengths,
{Di, Hi, Wi}, filter_spatial_lengths,
{Z, Y, X}, output_spatial_lengths,
{Do, Ho, Wo}, input_strides,
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1}, weights_strides,
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1}, output_strides,
{1, 1, 1}, conv_filter_strides,
{1, 1, 1}, conv_filter_dilations,
{1, 1, 1}, input_left_pads,
{1, 1, 1}) input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -76,6 +76,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -76,6 +76,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -88,6 +89,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -88,6 +89,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -108,6 +110,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -108,6 +110,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -35,6 +35,7 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -35,6 +35,7 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......
...@@ -792,6 +792,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -792,6 +792,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*weights_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1121,6 +1122,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1121,6 +1122,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1142,6 +1144,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1142,6 +1144,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1167,6 +1170,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1167,6 +1170,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1188,6 +1192,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1188,6 +1192,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{ {
if constexpr(is_GNHWK_GKYXC_GNHWC) const index_t WoStride = output_strides[4];
{ const auto KStride = Number<1>{};
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
} make_tuple(WoStride, KStride));
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{ {
if constexpr(is_GNHWK_GKYXC_GNHWC) const index_t NStride = input_strides[1];
{ const index_t HiStride = input_strides[3];
if constexpr(ConvBackwardWeightSpecialization == const index_t WiStride = input_strides[4];
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) const auto CStride = input_strides[2];
{ if constexpr(ConvBackwardWeightSpecialization ==
return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
}
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{ {
const index_t NStride = input_strides[1]; return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
const index_t HiStride = input_strides[3]; make_tuple(WiStride, CStride));
const index_t WiStride = input_strides[4];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
}
} }
else else
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(NStride, HiStride, WiStride, CStride));
} }
} }
// Builds the weight grid descriptor for the NDim == 2 overload.
//
// The result is a two-dimensional naive tensor descriptor with lengths
// (K, Y * X * C). The stride of the K dimension is taken from
// weights_strides[1]; the merged Y * X * C dimension always gets a
// compile-time unit stride, so no other entry of weights_strides is read.
// NOTE(review): this presumes the Y, X, C elements belonging to one K index
// are contiguous in memory — confirm against the weight layouts callers use.
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
                   const ck::index_t Y,
                   const ck::index_t X,
                   const ck::index_t C,
                   const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
    // Run-time stride between consecutive K indices.
    const auto k_stride = weights_strides[1];
    // Innermost (merged) dimension is addressed with a static stride of one.
    const auto unit_stride = Number<1>{};

    const auto wei_lengths = make_tuple(K, Y * X * C);
    const auto wei_strides = make_tuple(k_stride, unit_stride);
    return make_naive_tensor_descriptor(wei_lengths, wei_strides);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto constexpr static auto
make_out_grid_desc(const ck::index_t N, make_out_grid_desc(const ck::index_t N,
...@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{ {
if constexpr(is_GNDHWK_GKZYXC_GNDHWC) const index_t WoStride = output_strides[5];
{ const auto KStride = Number<1>{};
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
} make_tuple(WoStride, KStride));
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
...@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{ {
if constexpr(is_GNDHWK_GKZYXC_GNDHWC) const index_t NStride = input_strides[1];
{ const index_t DiStride = input_strides[3];
if constexpr(ConvBackwardWeightSpecialization == const index_t HiStride = input_strides[4];
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) const index_t WiStride = input_strides[5];
{ const auto CStride = input_strides[2];
return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); if constexpr(ConvBackwardWeightSpecialization ==
} ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
}
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{ {
const index_t NStride = input_strides[1]; return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
const index_t DiStride = input_strides[3]; make_tuple(WiStride, CStride));
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
} }
else else
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
} }
} }
// Builds the weight grid descriptor for the NDim == 3 overload.
//
// The result is a two-dimensional naive tensor descriptor with lengths
// (K, Z * Y * X * C). The stride of the K dimension is taken from
// weights_strides[1]; the merged Z * Y * X * C dimension always gets a
// compile-time unit stride, so no other entry of weights_strides is read.
// NOTE(review): this presumes the Z, Y, X, C elements belonging to one K index
// are contiguous in memory — confirm against the weight layouts callers use.
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
                   const ck::index_t Z,
                   const ck::index_t Y,
                   const ck::index_t X,
                   const ck::index_t C,
                   const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
    // Run-time stride between consecutive K indices.
    const auto k_stride = weights_strides[1];
    // Innermost (merged) dimension is addressed with a static stride of one.
    const auto unit_stride = Number<1>{};

    const auto wei_lengths = make_tuple(K, Z * Y * X * C);
    const auto wei_strides = make_tuple(k_stride, unit_stride);
    return make_naive_tensor_descriptor(wei_lengths, wei_strides);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const ck::index_t N,
...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
} }
...@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
} // function end } // function end
...@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -1059,6 +1019,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1059,6 +1019,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1104,6 +1065,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1104,6 +1065,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1350,6 +1312,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1350,6 +1312,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1371,6 +1334,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1371,6 +1334,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1398,6 +1362,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1398,6 +1362,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1419,6 +1384,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1419,6 +1384,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -140,6 +140,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -140,6 +140,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -152,6 +153,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -152,6 +153,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -172,6 +174,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -172,6 +174,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment