Commit 85bddfbc authored by Bartlomiej Kocot's avatar Bartlomiej Kocot Committed by Bartłomiej Kocot
Browse files

Add wei_strides to grouped conv3d wei to keep consistency

parent 7761e523
...@@ -109,6 +109,7 @@ bool run_grouped_conv_bwd_weight( ...@@ -109,6 +109,7 @@ bool run_grouped_conv_bwd_weight(
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides, const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides, const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
...@@ -160,6 +161,7 @@ bool run_grouped_conv_bwd_weight( ...@@ -160,6 +161,7 @@ bool run_grouped_conv_bwd_weight(
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -26,6 +26,7 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi ...@@ -26,6 +26,7 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X}; static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
...@@ -48,6 +49,7 @@ int main() ...@@ -48,6 +49,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -30,6 +30,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y ...@@ -30,6 +30,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1}; N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Y * X * C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1}; N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
...@@ -53,6 +55,7 @@ int main() ...@@ -53,6 +55,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z ...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
...@@ -56,6 +58,7 @@ int main() ...@@ -56,6 +58,7 @@ int main()
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z ...@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
...@@ -48,20 +50,20 @@ int main() ...@@ -48,20 +50,20 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(G,
G,
N, N,
K, K,
C, C,
{Di, Hi, Wi}, input_spatial_lengths,
{Z, Y, X}, filter_spatial_lengths,
{Do, Ho, Wo}, output_spatial_lengths,
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1}, input_strides,
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1}, weights_strides,
{1, 1, 1}, output_strides,
{1, 1, 1}, conv_filter_strides,
{1, 1, 1}, conv_filter_dilations,
{1, 1, 1}) input_left_pads,
input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -76,6 +76,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -76,6 +76,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -88,6 +89,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -88,6 +89,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -108,6 +110,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -108,6 +110,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -35,6 +35,7 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -35,6 +35,7 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......
...@@ -792,6 +792,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -792,6 +792,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*weights_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1121,6 +1122,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1121,6 +1122,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1142,6 +1144,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1142,6 +1144,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1167,6 +1170,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1167,6 +1170,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1188,6 +1192,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1188,6 +1192,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -244,23 +244,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -244,23 +244,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t Wo, const ck::index_t Wo,
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{
if constexpr(is_GNHWK_GKYXC_GNHWC)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{ {
const index_t WoStride = output_strides[4]; const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{}; const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WoStride, KStride)); make_tuple(WoStride, KStride));
} }
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr static auto constexpr static auto
...@@ -269,20 +258,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -269,20 +258,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t Wi, const ck::index_t Wi,
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{
if constexpr(is_GNHWK_GKYXC_GNHWC)
{
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
}
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{ {
const index_t NStride = input_strides[1]; const index_t NStride = input_strides[1];
const index_t HiStride = input_strides[3]; const index_t HiStride = input_strides[3];
...@@ -296,14 +271,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -296,14 +271,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} }
else else
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); make_tuple(NStride, HiStride, WiStride, CStride));
} }
} }
else
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); const auto CStride = Number<1>{};
} const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
...@@ -314,23 +297,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -314,23 +297,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t Wo, const ck::index_t Wo,
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{ {
const index_t WoStride = output_strides[5]; const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{}; const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride)); make_tuple(WoStride, KStride));
} }
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto constexpr static auto
...@@ -340,20 +312,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -340,20 +312,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t Wi, const ck::index_t Wi,
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
{
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
}
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{ {
const index_t NStride = input_strides[1]; const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3]; const index_t DiStride = input_strides[3];
...@@ -373,10 +331,20 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -373,10 +331,20 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
} }
} }
else
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Z,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); const auto CStride = Number<1>{};
} const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
make_tuple(KStride, CStride));
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
} }
...@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
} // function end } // function end
...@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -1059,6 +1019,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1059,6 +1019,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1104,6 +1065,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1104,6 +1065,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1350,6 +1312,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1350,6 +1312,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1371,6 +1334,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1371,6 +1334,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -1398,6 +1362,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1398,6 +1362,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -1419,6 +1384,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1419,6 +1384,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -140,6 +140,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -140,6 +140,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -152,6 +153,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -152,6 +153,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -172,6 +174,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -172,6 +174,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides, input_strides,
weights_strides,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment