Commit 85bddfbc authored by Bartlomiej Kocot's avatar Bartlomiej Kocot Committed by Bartłomiej Kocot
Browse files

Add wei_strides to grouped conv3d wei to keep consistency

parent 7761e523
......@@ -109,6 +109,7 @@ bool run_grouped_conv_bwd_weight(
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
......@@ -160,6 +161,7 @@ bool run_grouped_conv_bwd_weight(
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -26,6 +26,7 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
......@@ -48,6 +49,7 @@ int main()
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -30,6 +30,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Y * X * C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
......@@ -53,6 +55,7 @@ int main()
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
......@@ -56,6 +58,7 @@ int main()
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -33,6 +33,8 @@ static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, Y* X* C, X* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
......@@ -48,20 +50,20 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
G,
N,
K,
C,
{Di, Hi, Wi},
{Z, Y, X},
{Do, Ho, Wo},
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -76,6 +76,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
......@@ -88,6 +89,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
......@@ -108,6 +110,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -35,6 +35,7 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......
......@@ -792,6 +792,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*weights_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1121,6 +1122,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1142,6 +1144,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......@@ -1167,6 +1170,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1188,6 +1192,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{
if constexpr(is_GNHWK_GKYXC_GNHWC)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
......@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{
if constexpr(is_GNHWK_GKYXC_GNHWC)
{
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
}
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
const index_t NStride = input_strides[1];
const index_t HiStride = input_strides[3];
const index_t WiStride = input_strides[4];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
const index_t NStride = input_strides[1];
const index_t HiStride = input_strides[3];
const index_t WiStride = input_strides[4];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
}
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(NStride, HiStride, WiStride, CStride));
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_out_grid_desc(const ck::index_t N,
......@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
......@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
{
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
}
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Z,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
make_tuple(KStride, CStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
......@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
......@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_grid_desc);
}
else
{
......@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_grid_desc);
}
}
......@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
......@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_grid_desc);
}
else
{
......@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_grid_desc);
}
} // function end
......@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths,
strides,
strides,
strides,
params,
params,
params,
......@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths,
strides,
strides,
strides,
params,
params,
params,
......@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths,
strides,
strides,
strides,
params,
params,
params,
......@@ -1059,6 +1019,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1104,6 +1065,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......@@ -1350,6 +1312,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1371,6 +1334,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......@@ -1398,6 +1362,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
......@@ -1419,6 +1384,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -140,6 +140,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
......@@ -152,6 +153,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
......@@ -172,6 +174,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
weights_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment