Commit d9eadda9 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Change gemm layout to 3d

parent 33b4b52c
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
using InDataType = ck::half_t; using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using ImageLayout = ck::tensor_layout::convolution::GNHWC; using ImageLayout = ck::tensor_layout::convolution::NHWGC;
static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 2; static constexpr ck::index_t G = 2;
...@@ -56,7 +56,7 @@ int main() ...@@ -56,7 +56,7 @@ int main()
// However, CK's API only accept length and stride with order of GNCHW // However, CK's API only accept length and stride with order of GNCHW
// Hence, we need to adjust the order of stride // Hence, we need to adjust the order of stride
std::array<ck::index_t, 5> image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array<ck::index_t, 5> image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
std::array<ck::index_t, 2> gemm_strides{Y * X * C, 1}; std::array<ck::index_t, 3> gemm_strides{Y * X * C, G * Y * X * C, 1};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1}; std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1}; std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
using InDataType = ck::half_t; using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using ImageLayout = ck::tensor_layout::convolution::GNHWC; using ImageLayout = ck::tensor_layout::convolution::NHWGC;
static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 2; static constexpr ck::index_t G = 2;
...@@ -52,11 +52,11 @@ int main() ...@@ -52,11 +52,11 @@ int main()
std::array<ck::index_t, 2> wei_spatial_lengths{Y, X}; std::array<ck::index_t, 2> wei_spatial_lengths{Y, X};
std::array<ck::index_t, 2> out_spatial_lengths{Ho, Wo}; std::array<ck::index_t, 2> out_spatial_lengths{Ho, Wo};
// We have NHWGC in memory space (G is dummy) // We have NHWGC in memory space
// However, CK's API only accept length and stride with order of GNCHW // However, CK's API only accept length and stride with order of GNCHW
// Hence, we need to adjust the order of stride // Hence, we need to adjust the order of stride
std::array<ck::index_t, 5> image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array<ck::index_t, 5> image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
std::array<ck::index_t, 2> gemm_strides{Y * X * C, 1}; std::array<ck::index_t, 3> gemm_strides{Y * X * C, G * Y * X * C, 1};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1}; std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1}; std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
......
...@@ -24,15 +24,14 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -24,15 +24,14 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv
const auto N = conv_params.N_; const auto N = conv_params.N_;
const auto C = conv_params.C_; const auto C = conv_params.C_;
const ck::index_t GNDoHoWo = const ck::index_t NDoHoWo =
G * N * N * ck::accumulate_n<ck::index_t>(
ck::accumulate_n<ck::index_t>( conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const ck::index_t CZYX = const ck::index_t CZYX =
C * ck::accumulate_n<ck::index_t>( C * ck::accumulate_n<ck::index_t>(
conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const auto in_desc = HostTensorDescriptor({GNDoHoWo, CZYX}); const auto in_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
const auto out_desc = const auto out_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params); ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
...@@ -40,7 +39,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -40,7 +39,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{};
std::array<ck::index_t, 2> gemm_m_k_strides{}; std::array<ck::index_t, 3> gemm_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -51,7 +50,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -51,7 +50,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv
copy(conv_params.input_spatial_lengths_, input_spatial_lengths); copy(conv_params.input_spatial_lengths_, input_spatial_lengths);
copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_params.output_spatial_lengths_, output_spatial_lengths); copy(conv_params.output_spatial_lengths_, output_spatial_lengths);
copy(in_desc.GetStrides(), gemm_m_k_strides); copy(in_desc.GetStrides(), gemm_g_m_k_strides);
copy(out_desc.GetStrides(), image_g_n_c_wis_strides); copy(out_desc.GetStrides(), image_g_n_c_wis_strides);
copy(conv_params.conv_filter_strides_, conv_filter_strides); copy(conv_params.conv_filter_strides_, conv_filter_strides);
copy(conv_params.conv_filter_dilations_, conv_filter_dilations); copy(conv_params.conv_filter_dilations_, conv_filter_dilations);
...@@ -94,7 +93,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -94,7 +93,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -110,7 +109,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -110,7 +109,7 @@ bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::Conv
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t num_btype = GNDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
......
...@@ -24,23 +24,22 @@ bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -24,23 +24,22 @@ bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::Conv
const auto N = conv_params.N_; const auto N = conv_params.N_;
const auto C = conv_params.C_; const auto C = conv_params.C_;
const ck::index_t GNDoHoWo = const ck::index_t NDoHoWo =
G * N * N * ck::accumulate_n<ck::index_t>(
ck::accumulate_n<ck::index_t>( conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const ck::index_t CZYX = const ck::index_t CZYX =
C * ck::accumulate_n<ck::index_t>( C * ck::accumulate_n<ck::index_t>(
conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const auto in_desc = const auto in_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params); ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
const auto out_desc = HostTensorDescriptor({GNDoHoWo, CZYX}); const auto out_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{};
std::array<ck::index_t, 2> gemm_m_k_strides{}; std::array<ck::index_t, 3> gemm_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -110,7 +109,7 @@ bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::Conv ...@@ -110,7 +109,7 @@ bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::Conv
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t num_btype = GNDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
......
...@@ -18,8 +18,8 @@ namespace device { ...@@ -18,8 +18,8 @@ namespace device {
* the gemm problem (Image to Column) and * the gemm problem (Image to Column) and
* conversion gemm form to the image (Column to Image). * conversion gemm form to the image (Column to Image).
* Supported layouts: * Supported layouts:
* [G, N, Di, Hi, Wi, C] <-> [G * N * Do * Ho * Wo, Z * Y * X * C] * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
* [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo * G, Z * Y * X * C] * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Input Layout. * \tparam ImageLayout Input Layout.
...@@ -47,7 +47,7 @@ struct DeviceConvTensorRearrange : public BaseOperator ...@@ -47,7 +47,7 @@ struct DeviceConvTensorRearrange : public BaseOperator
* \param filter_spatial_lengths Filter spatial lengths. * \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths. * \param output_spatial_lengths Output spatial lengths.
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
* \param gemm_m_k_strides Gemm form strides. * \param gemm_g_m_k_strides Gemm form strides.
* \param conv_filter_strides Convolution filter strides. * \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations. * \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads. * \param input_left_pads Convolution left pads.
...@@ -64,7 +64,7 @@ struct DeviceConvTensorRearrange : public BaseOperator ...@@ -64,7 +64,7 @@ struct DeviceConvTensorRearrange : public BaseOperator
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
......
...@@ -25,9 +25,9 @@ namespace tensor_operation { ...@@ -25,9 +25,9 @@ namespace tensor_operation {
namespace device { namespace device {
// Column to Image: // Column to Image:
// input : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C] // input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
// output : input image [G, N, Di, Hi, Wi, C] // output : input image [G, N, Di, Hi, Wi, C]
// input : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C] // input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
// output : input image [N, Di, Hi, Wi, G, C] // output : input image [N, Di, Hi, Wi, G, C]
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ImageLayout, typename ImageLayout,
...@@ -96,13 +96,12 @@ struct DeviceColumnToImageImpl ...@@ -96,13 +96,12 @@ struct DeviceColumnToImageImpl
// Make column form descriptor // Make column form descriptor
static auto static auto
MakeInputDescriptor_M_K(const ck::index_t G, MakeInputDescriptor_M_K(const ck::index_t N,
const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& independent_filters, const std::array<index_t, NDimSpatial>& independent_filters,
const std::array<index_t, NDimSpatial>& effs) const std::array<index_t, NDimSpatial>& effs)
{ {
...@@ -112,26 +111,23 @@ struct DeviceColumnToImageImpl ...@@ -112,26 +111,23 @@ struct DeviceColumnToImageImpl
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t AdditionalGroupStride = is_NSpatialGC ? G : 1; const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
const index_t NStride =
DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1] * AdditionalGroupStride;
// Calculate the appropriate stride for each set of independent filters // Calculate the appropriate stride for each set of independent filters
// in each dimension // in each dimension
const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
gemm_m_k_strides[I0] * AdditionalGroupStride; gemm_g_m_k_strides[I1];
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
output_spatial_lengths[XIdx] * gemm_m_k_strides[I0] * output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
AdditionalGroupStride;
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
gemm_m_k_strides[I0] * AdditionalGroupStride; gemm_g_m_k_strides[I1];
// Create descriptor for independent filters in each dimension and // Create descriptor for independent filters in each dimension and
// then merge them into column form // then merge them into column form
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const auto desc_gemm_form = const auto desc_gemm_form =
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
make_tuple(NStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
...@@ -145,7 +141,7 @@ struct DeviceColumnToImageImpl ...@@ -145,7 +141,7 @@ struct DeviceColumnToImageImpl
{ {
const auto desc_gemm_form = make_naive_tensor_descriptor( const auto desc_gemm_form = make_naive_tensor_descriptor(
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform( make_tuple(make_merge_transform(
...@@ -164,7 +160,7 @@ struct DeviceColumnToImageImpl ...@@ -164,7 +160,7 @@ struct DeviceColumnToImageImpl
independent_filters[YIdx], independent_filters[YIdx],
independent_filters[XIdx], independent_filters[XIdx],
CZYX), CZYX),
make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N, make_tuple(make_merge_transform(make_tuple(N,
...@@ -259,7 +255,7 @@ struct DeviceColumnToImageImpl ...@@ -259,7 +255,7 @@ struct DeviceColumnToImageImpl
} }
using InputGridDesc = using InputGridDesc =
remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, 1, {}, {}, {}, {}, {}, {}))>; remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}))>;
using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K( using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(
1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; 1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
...@@ -292,7 +288,7 @@ struct DeviceColumnToImageImpl ...@@ -292,7 +288,7 @@ struct DeviceColumnToImageImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -308,21 +304,7 @@ struct DeviceColumnToImageImpl ...@@ -308,21 +304,7 @@ struct DeviceColumnToImageImpl
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
using namespace tensor_layout::convolution; compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
if constexpr(is_NSpatialGC)
{
compute_ptr_offset_of_batch_.BatchStrideA_ =
gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
}
else if constexpr(is_GNSpatialC)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
compute_ptr_offset_of_batch_.BatchStrideA_ =
NDoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
}
compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0]; compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
const index_t x_eff = const index_t x_eff =
...@@ -385,13 +367,12 @@ struct DeviceColumnToImageImpl ...@@ -385,13 +367,12 @@ struct DeviceColumnToImageImpl
continue; continue;
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
MakeInputDescriptor_M_K(G, MakeInputDescriptor_M_K(N,
N,
C, C,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
conv_filter_strides, conv_filter_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
independent_filters, independent_filters,
effs); effs);
const auto out_grid_desc_m_k = const auto out_grid_desc_m_k =
...@@ -421,13 +402,12 @@ struct DeviceColumnToImageImpl ...@@ -421,13 +402,12 @@ struct DeviceColumnToImageImpl
const index_t z_offset_with_pad = const index_t z_offset_with_pad =
math::max(0, z_img_offset - input_left_pads[ZIdx]); math::max(0, z_img_offset - input_left_pads[ZIdx]);
const index_t AdditionalGroupStride = is_NSpatialGC ? G : 1;
// Memory offsets to next set of independent filters, // Memory offsets to next set of independent filters,
// move to independent filters in each dimension // move to independent filters in each dimension
const index_t in_offset = const index_t in_offset =
(x_idx + y_idx * output_spatial_lengths[XIdx] + (x_idx + y_idx * output_spatial_lengths[XIdx] +
z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) * z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
gemm_m_k_strides[0] * AdditionalGroupStride; gemm_g_m_k_strides[I1];
// Move to independent filters in appropriate dimensions // Move to independent filters in appropriate dimensions
const index_t out_offset = const index_t out_offset =
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
...@@ -583,7 +563,7 @@ struct DeviceColumnToImageImpl ...@@ -583,7 +563,7 @@ struct DeviceColumnToImageImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -598,7 +578,7 @@ struct DeviceColumnToImageImpl ...@@ -598,7 +578,7 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -617,7 +597,7 @@ struct DeviceColumnToImageImpl ...@@ -617,7 +597,7 @@ struct DeviceColumnToImageImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -632,7 +612,7 @@ struct DeviceColumnToImageImpl ...@@ -632,7 +612,7 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -116,12 +116,11 @@ struct DeviceImageToColumnImpl ...@@ -116,12 +116,11 @@ struct DeviceImageToColumnImpl
} }
static auto static auto
MakeOutDescriptor_M_K(const ck::index_t G, MakeOutDescriptor_M_K(const ck::index_t N,
const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& gemm_m_k_strides) const std::array<index_t, 3>& gemm_g_m_k_strides)
{ {
const index_t NDoHoWo = const index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<index_t>(
...@@ -130,24 +129,14 @@ struct DeviceImageToColumnImpl ...@@ -130,24 +129,14 @@ struct DeviceImageToColumnImpl
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
if constexpr(is_NSpatialGC) const auto desc_mraw_kraw = make_naive_tensor_descriptor(
{ make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
const auto desc_mraw_kraw = make_naive_tensor_descriptor( return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
make_tuple(NDoHoWo, CZYX),
make_tuple(gemm_m_k_strides[I0] * G, gemm_m_k_strides[I1]));
return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
}
else if constexpr(is_GNSpatialC)
{
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
}
} }
using InputGridDesc = using InputGridDesc =
remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))>; remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))>;
using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(1, 1, 1, {}, {}, {}))>; using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}))>;
using Block2ETileMap = remove_cvref_t< using Block2ETileMap = remove_cvref_t<
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>( decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
...@@ -178,7 +167,7 @@ struct DeviceImageToColumnImpl ...@@ -178,7 +167,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -207,22 +196,10 @@ struct DeviceImageToColumnImpl ...@@ -207,22 +196,10 @@ struct DeviceImageToColumnImpl
input_right_pads); input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K( out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
G, N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides); N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);
compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0]; compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
if constexpr(is_NSpatialGC) compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
{
compute_ptr_offset_of_batch_.BatchStrideC_ =
gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
}
else if constexpr(is_GNSpatialC)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
compute_ptr_offset_of_batch_.BatchStrideC_ =
NDoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
}
} }
void Print() const void Print() const
...@@ -346,7 +323,7 @@ struct DeviceImageToColumnImpl ...@@ -346,7 +323,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -361,7 +338,7 @@ struct DeviceImageToColumnImpl ...@@ -361,7 +338,7 @@ struct DeviceImageToColumnImpl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -380,7 +357,7 @@ struct DeviceImageToColumnImpl ...@@ -380,7 +357,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -395,7 +372,7 @@ struct DeviceImageToColumnImpl ...@@ -395,7 +372,7 @@ struct DeviceImageToColumnImpl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -92,9 +92,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -92,9 +92,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
using namespace ck::tensor_layout::convolution;
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 && if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2)) arg.input_.GetNumOfDimension() == 3))
{ {
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
...@@ -111,14 +110,6 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -111,14 +110,6 @@ struct ReferenceColumnToImage : public device::BaseOperator
{ {
index_t row = n * Wo + wo; index_t row = n * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNWC>)
{
row = g * N * Wo + n * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NWGC>)
{
row = n * Wo * G + wo * G + g;
}
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
...@@ -131,7 +122,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -131,7 +122,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{ {
float v_in = ck::type_convert<float>(arg.input_(row, column)); float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi)); float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) = arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out); ck::type_convert<OutDataType>(v_in + v_out);
...@@ -156,16 +148,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -156,16 +148,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = 0; index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNHWC>)
{
row = g * N * Ho * Wo + n * Ho * Wo + ho * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NHWGC>)
{
row = n * Ho * Wo * G + ho * Wo * G + wo * G + g;
}
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
...@@ -192,7 +176,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -192,7 +176,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[4]) arg.output_.GetLengths()[4])
{ {
float v_in = float v_in =
ck::type_convert<float>(arg.input_(row, column)); ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>( float v_out = ck::type_convert<float>(
arg.output_(g, n, c, hi, wi)); arg.output_(g, n, c, hi, wi));
arg.output_(g, n, c, hi, wi) = arg.output_(g, n, c, hi, wi) =
...@@ -223,18 +207,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -223,18 +207,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = 0; index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNDHWC>)
{
row = g * N * Do * Ho * Wo + n * Do * Ho * Wo + d_o * Ho * Wo +
ho * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NDHWGC>)
{
row = n * Do * Ho * Wo * G + d_o * Ho * Wo * G + ho * Wo * G +
wo * G + g;
}
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
...@@ -271,7 +245,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -271,7 +245,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[5]) arg.output_.GetLengths()[5])
{ {
float v_in = ck::type_convert<float>( float v_in = ck::type_convert<float>(
arg.input_(row, column)); arg.input_(g, row, column));
float v_out = ck::type_convert<float>( float v_out = ck::type_convert<float>(
arg.output_(g, n, c, di, hi, wi)); arg.output_(g, n, c, di, hi, wi));
arg.output_(g, n, c, di, hi, wi) = arg.output_(g, n, c, di, hi, wi) =
...@@ -329,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -329,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) && if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX))) arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{ {
return false; return false;
} }
......
...@@ -92,9 +92,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -92,9 +92,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
using namespace ck::tensor_layout::convolution;
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == 2)) arg.output_.GetNumOfDimension() == 3))
{ {
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
...@@ -107,16 +106,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -107,16 +106,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) { auto func = [&](auto g, auto n, auto wo) {
index_t row = 0; index_t row = n * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNWC>)
{
row = g * N * Wo + n * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NWGC>)
{
row = n * Wo * G + wo * G + g;
}
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
...@@ -129,8 +120,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -129,8 +120,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
InDataType v_in = arg.input_(g, n, c, wi); InDataType v_in = arg.input_(g, n, c, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in); arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
} }
...@@ -147,16 +138,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -147,16 +138,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Wo = arg.output_spatial_lengths_[1]; const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) { auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = 0; index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNHWC>)
{
row = g * N * Ho * Wo + n * Ho * Wo + ho * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NHWGC>)
{
row = n * Ho * Wo * G + ho * Wo * G + wo * G + g;
}
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
...@@ -178,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -178,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
InDataType v_in = arg.input_(g, n, c, hi, wi); InDataType v_in = arg.input_(g, n, c, hi, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in); arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
} }
...@@ -198,17 +182,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -198,17 +182,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Wo = arg.output_spatial_lengths_[2]; const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = 0; index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
if constexpr(std::is_same_v<ImageLayout, GNDHWC>)
{
row =
g * N * Do * Ho * Wo + n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
}
else if constexpr(std::is_same_v<ImageLayout, NDHWGC>)
{
row = n * Do * Ho * Wo * G + d_o * Ho * Wo * G + ho * Wo * G + wo * G + g;
}
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
...@@ -239,7 +214,7 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -239,7 +214,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
InDataType v_in = arg.input_(g, n, c, di, hi, wi); InDataType v_in = arg.input_(g, n, c, di, hi, wi);
arg.output_(row, column) = arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in); ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
...@@ -292,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -292,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) && if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.output_.GetLengths()[1] == static_cast<std::size_t>(CZYX))) arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{ {
return false; return false;
} }
......
...@@ -93,6 +93,26 @@ static auto make_ref_op() ...@@ -93,6 +93,26 @@ static auto make_ref_op()
} }
} }
template <typename InputLayout>
static auto create_gemm_desc(const ck::index_t G, const ck::index_t NDoHoWo, const ck::index_t CZYX)
{
using namespace ck::tensor_layout::convolution;
if constexpr(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>)
{
return HostTensorDescriptor({G, NDoHoWo, CZYX});
}
else if constexpr(std::is_same_v<InputLayout, NWGC> || std::is_same_v<InputLayout, NHWGC> ||
std::is_same_v<InputLayout, NDHWGC>)
{
return HostTensorDescriptor({G, NDoHoWo, CZYX}, {CZYX, CZYX * G, 1});
}
else
{
throw std::runtime_error("Unsupported layout!");
}
}
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename InputLayout, typename InputLayout,
typename InputDataType, typename InputDataType,
...@@ -104,8 +124,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, ...@@ -104,8 +124,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
const ck::index_t GNDoHoWo = const ck::index_t NDoHoWo =
conv_param.G_ * conv_param.N_ * conv_param.N_ *
ck::accumulate_n<ck::index_t>( ck::accumulate_n<ck::index_t>(
conv_param.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); conv_param.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const ck::index_t CZYX = const ck::index_t CZYX =
...@@ -116,13 +136,13 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, ...@@ -116,13 +136,13 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
const auto image_desc = const auto image_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InputLayout>( ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InputLayout>(
conv_param); conv_param);
const auto gemm_desc = HostTensorDescriptor({GNDoHoWo, CZYX}); const auto gemm_desc = create_gemm_desc<InputLayout>(conv_param.G_, NDoHoWo, CZYX);
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> image_g_n_c_wis_strides{};
std::array<ck::index_t, 2> gemm_m_k_strides{}; std::array<ck::index_t, 3> gemm_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -134,7 +154,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, ...@@ -134,7 +154,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_param.output_spatial_lengths_, output_spatial_lengths); copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
copy(image_desc.GetStrides(), image_g_n_c_wis_strides); copy(image_desc.GetStrides(), image_g_n_c_wis_strides);
copy(gemm_desc.GetStrides(), gemm_m_k_strides); copy(gemm_desc.GetStrides(), gemm_g_m_k_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -219,7 +239,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, ...@@ -219,7 +239,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -235,7 +255,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, ...@@ -235,7 +255,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
float avg_time = float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t num_btype = std::size_t num_btype =
GNDoHoWo * CZYX * (sizeof(OutputDataType) + sizeof(InputDataType)); conv_param.G_ * NDoHoWo * CZYX * (sizeof(OutputDataType) + sizeof(InputDataType));
float gb_per_sec = num_btype / 1.E6 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl; << op_name << std::endl;
......
...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
const auto image_desc = const auto image_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>( ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
conv_param); conv_param);
const auto gemm_desc = HostTensorDescriptor({NDoHoWo, CZYX}); const auto gemm_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{}; std::array<ck::index_t, 3> output_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_param.output_spatial_lengths_, output_spatial_lengths); copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
copy(image_desc.GetStrides(), input_g_n_c_wis_strides); copy(image_desc.GetStrides(), input_g_n_c_wis_strides);
copy(gemm_desc.GetStrides(), output_m_k_strides); copy(gemm_desc.GetStrides(), output_g_m_k_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -107,7 +107,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -107,7 +107,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -127,7 +127,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -127,7 +127,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment