Unverified Commit 22443f7a authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add wei_strides to grouped conv3d wei to keep consistency (#817)



* Add wei_strides to grouped conv3d wei to keep consistency

* Fix strides in client examples

* Unify backward weight api with forward

* Fix for example

* Fixes for examples

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent 2474dddb
...@@ -32,63 +32,49 @@ struct SimpleDeviceMem ...@@ -32,63 +32,49 @@ struct SimpleDeviceMem
}; };
template <ck::index_t NumDimSpatial> template <ck::index_t NumDimSpatial>
std::size_t GetFlops(ck::index_t G, std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
ck::index_t N, const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{ {
constexpr ck::index_t spatial_offset = 3;
const auto C = filter_lengths[2];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product> // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G * N * K * C * return static_cast<std::size_t>(2) * C *
std::accumulate(std::begin(output_spatial_lengths), std::accumulate(std::begin(output_lengths),
std::end(output_spatial_lengths), std::end(output_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>()) * std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths), std::accumulate(std::begin(filter_lengths) + spatial_offset,
std::end(filter_spatial_lengths), std::end(filter_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>()); std::multiplies<>());
} }
template <typename InDataType, ck::index_t NumDimSpatial> template <typename InDataType, ck::index_t NumDimSpatial>
std::size_t GetInputByte(ck::index_t G, std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
ck::index_t N,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
{ {
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) + // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * (G * N * C * return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
std::accumulate(std::begin(input_spatial_lengths), std::end(input_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>())); std::multiplies<>()));
} }
template <typename WeiDataType, ck::index_t NumDimSpatial> template <typename WeiDataType, ck::index_t NumDimSpatial>
std::size_t GetWeightByte(ck::index_t G, std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{ {
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) + // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * (G * K * C * return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>())); std::multiplies<>()));
} }
template <typename OutDataType, ck::index_t NumDimSpatial> template <typename OutDataType, ck::index_t NumDimSpatial>
std::size_t GetOutputByte(ck::index_t G, std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
ck::index_t N,
ck::index_t K,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
{ {
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>); // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G * N * K * return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
std::accumulate(std::begin(output_spatial_lengths), std::end(output_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<std::size_t>())); std::multiplies<std::size_t>()));
} }
...@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial, ...@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout, typename WeiLayout,
typename OutLayout> typename OutLayout>
bool run_grouped_conv_bwd_weight( bool run_grouped_conv_bwd_weight(
const ck::index_t G, const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides, const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides, const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
...@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight( ...@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{ {
ck::index_t split_k = 2; ck::index_t split_k = 2;
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths)); SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths)); SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths)); SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout, InLayout,
...@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight( ...@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_tflops = 0; float best_tflops = 0;
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
// profile device operation instances // profile device operation instances
std::cout << "Run all instances and do timing" << std::endl; std::cout << "Run all instances and do timing" << std::endl;
...@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight( ...@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(), wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
G, input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight( ...@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
{ {
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
std::size_t num_bytes = GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) + GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / avg_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
...@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight( ...@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(), wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
G, input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192; ...@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3; static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, 1, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
...@@ -40,14 +41,11 @@ int main() ...@@ -40,14 +41,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28; ...@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1}; N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Y * X * C, Y* X* C, 1, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1}; N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
...@@ -45,14 +47,11 @@ int main() ...@@ -45,14 +47,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
...@@ -48,14 +50,11 @@ int main() ...@@ -48,14 +50,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
...@@ -48,20 +50,16 @@ int main() ...@@ -48,20 +50,16 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(input_lengths,
G, input_strides,
N, filter_lengths,
K, weights_strides,
C, output_lengths,
{Di, Hi, Wi}, output_strides,
{Z, Y, X}, conv_filter_strides,
{Do, Ho, Wo}, conv_filter_dilations,
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1}, input_left_pads,
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1}, input_right_pads)
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -72,10 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -72,10 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
// init to 0 // init to 0
wei_device_buf.SetZero(); wei_device_buf.SetZero();
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -84,10 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -84,10 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -100,14 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -100,14 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.G_, input_lengths,
conv_param.N_,
conv_param.K_,
conv_param.C_,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
void* p_wei, void* p_wei,
const void* p_out, const void* p_out,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......
...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
Conv_G_{G}, Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{N}, Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{K}, Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{C}, Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{input_spatial_lengths}, input_spatial_lengths_{},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{},
output_spatial_lengths_{output_spatial_lengths}, output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, Conv_N_,
K, Conv_K_,
C, Conv_C_,
input_spatial_lengths, input_spatial_lengths_,
filter_spatial_lengths, filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths_,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K * Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths), std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths), end(output_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ = compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C * Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths), std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths), end(input_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths), std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths), end(filter_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
} }
...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_; std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
...@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const InDataType* p_in_grid, static auto
WeiDataType* p_wei_grid, MakeArgument(const InDataType* p_in_grid,
const OutDataType* p_out_grid, WeiDataType* p_wei_grid,
const ck::index_t G, const OutDataType* p_out_grid,
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, InElementwiseOperation in_element_op,
const std::array<ck::index_t, NDimSpatial>& input_right_pads, WeiElementwiseOperation wei_element_op,
InElementwiseOperation in_element_op, OutElementwiseOperation out_element_op,
WeiElementwiseOperation wei_element_op, ck::index_t split_k)
OutElementwiseOperation out_element_op,
ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -136,10 +136,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -136,10 +136,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
bool all_pass = true; bool all_pass = true;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -148,10 +149,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -148,10 +149,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -164,14 +166,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ...@@ -164,14 +166,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.G_, input_lengths,
conv_param.N_,
conv_param.K_,
conv_param.C_,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>( ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param); conv_param);
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{}; std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
...@@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
...@@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto argument = conv.MakeArgument(nullptr, auto argument = conv.MakeArgument(nullptr,
nullptr, nullptr,
nullptr, nullptr,
conv_param.G_, input_lengths,
conv_param.N_,
conv_param.K_,
conv_param.C_,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment