Commit 156423d6 authored by rocking's avatar rocking
Browse files

Do not hardcode stride

parent 74284dd2
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" #include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -18,6 +18,43 @@ ...@@ -18,6 +18,43 @@
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
// Compute the per-dimension strides, reported in (N, C, D, H, W) order, for a
// densely-packed 5-D tensor stored in either NCDHW or NDHWC memory layout.
//
// @param N_     batch size (unused: strides of the remaining dims do not
//               depend on it, and the N stride is derived from C/D/H/W)
// @param C_     channel count
// @param D,H,W  spatial extents (depth, height, width)
// @param layout tag type selecting the memory layout (NCDHW or NDHWC)
// @return strides ordered {N, C, D, H, W} regardless of memory layout
template <typename TensorLayout>
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
                                                ck::index_t C_,
                                                ck::index_t D,
                                                ck::index_t H,
                                                ck::index_t W,
                                                TensorLayout layout)
{
    using namespace ck::literals;
    (void)N_;
    (void)layout;
    if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value)
    {
        return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
    }
    else if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value)
    {
        return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
    }
    else
    {
        // Previously the function simply fell off the end here for any other
        // layout — undefined behavior for a non-void function. Fail loudly at
        // compile time instead (the condition is template-dependent, so the
        // static_assert only fires when this branch is instantiated).
        static_assert(
            ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
                ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value,
            "f_tensor_strides_ncdhw: unsupported tensor layout (expected NCDHW or NDHWC)");
        return {};
    }
}
// Build a HostTensorDescriptor for a densely-packed 5-D tensor with logical
// dimension order (N, C, D, H, W), using strides appropriate for the given
// memory layout (NCDHW or NDHWC).
//
// @param N_,C_   batch size and channel count
// @param D,H,W   spatial extents (depth, height, width)
// @param layout  tag type selecting the memory layout
// @return descriptor with lengths {N, C, D, H, W} and layout-specific strides
template <typename TensorLayout>
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
                                              std::size_t C_,
                                              std::size_t D,
                                              std::size_t H,
                                              std::size_t W,
                                              TensorLayout layout)
{
    using namespace ck::literals;
    (void)layout;
    if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value)
    {
        return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
    }
    else if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value)
    {
        return HostTensorDescriptor({N_, C_, D, H, W},
                                    {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
    }
    else
    {
        // Previously any other layout fell off the end of this non-void
        // function — undefined behavior. Reject unsupported layouts at
        // compile time instead (dependent condition: fires only when this
        // branch is instantiated).
        static_assert(
            ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
                ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value,
            "f_host_tensor_descriptor: unsupported tensor layout (expected NCDHW or NDHWC)");
        return HostTensorDescriptor({}, {});
    }
}
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename ComputeDataType, typename ComputeDataType,
...@@ -48,20 +85,19 @@ bool pool3d_test(bool do_verification, ...@@ -48,20 +85,19 @@ bool pool3d_test(bool do_verification,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C< ck::tensor_operation::device::DevicePool3dFwdImpl<InDataType, // InDataType
InDataType, // InDataType OutDataType, // OutDataType
OutDataType, // OutDataType IndexDataType, // IndexDataType
IndexDataType, // IndexDataType ComputeDataType, // ComputeDataType
ComputeDataType, // ComputeDataType ReduceOpId,
ReduceOpId, OutputIndex,
OutputIndex, 64, // BlockSize
64, // BlockSize 64, // ReduceMThreadClusterSize
64, // ReduceMThreadClusterSize 1, // ReduceKThreadClusterSize
1, // ReduceKThreadClusterSize 1, // ReduceMThreadSliceSize
4, // ReduceMThreadSliceSize 1, // ReduceKThreadSliceSize
1, // ReduceKThreadSliceSize 1, // InSrcOutDstVectorSize
4>; // InSrcOutDstVectorSize false>; // IsFastestDimReduced
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
...@@ -72,28 +108,6 @@ bool pool3d_test(bool do_verification, ...@@ -72,28 +108,6 @@ bool pool3d_test(bool do_verification,
const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w}; const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w};
const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w}; const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w};
// tensor layout
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t D,
std::size_t H,
std::size_t W,
auto layout) {
using namespace ck::literals;
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W},
{C_ * D * H * W, D * H * W, H * W, W, 1_uz});
}
else if constexpr(ck::is_same<decltype(layout),
ck::tensor_layout::convolution::NDHWC>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
};
Tensor<InDataType> in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{})); Tensor<InDataType> in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{}));
Tensor<OutDataType> out_n_c_do_ho_wo_host( Tensor<OutDataType> out_n_c_do_ho_wo_host(
f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{}));
...@@ -126,9 +140,9 @@ bool pool3d_test(bool do_verification, ...@@ -126,9 +140,9 @@ bool pool3d_test(bool do_verification,
{N, C, Di, Hi, Wi}, {N, C, Di, Hi, Wi},
{Z, Y, X}, {Z, Y, X},
{N, C, Do, Ho, Wo}, {N, C, Do, Ho, Wo},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}, f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, InLayout{}),
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -29,8 +30,9 @@ template <typename InDataType, ...@@ -29,8 +30,9 @@ template <typename InDataType,
ck::index_t KThreadClusterSize, ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize, ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize,
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C bool IsFastestDimReduced>
struct DevicePool3dFwdImpl
: public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> : public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -51,29 +53,27 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -51,29 +53,27 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced.
static constexpr index_t InSrcOutDstVectorDim = 0;
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_lengths,
ck::index_t C, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_stride,
std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads)
{ {
const index_t Di = input_spatial_lengths[0]; const index_t N = input_lengths[0];
const index_t Hi = input_spatial_lengths[1]; const index_t C = input_lengths[1];
const index_t Wi = input_spatial_lengths[2]; const index_t Di = input_lengths[2];
const index_t Hi = input_lengths[3];
const index_t Wi = input_lengths[4];
const index_t Do = output_spatial_lengths[0]; const index_t Do = output_lengths[2];
const index_t Ho = output_spatial_lengths[1]; const index_t Ho = output_lengths[3];
const index_t Wo = output_spatial_lengths[2]; const index_t Wo = output_lengths[4];
const index_t Z = window_spatial_lengths[0]; const index_t Z = window_spatial_lengths[0];
const index_t Y = window_spatial_lengths[1]; const index_t Y = window_spatial_lengths[1];
...@@ -98,8 +98,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -98,8 +98,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// A[ReduceM, ReduceK] // A[ReduceM, ReduceK]
const auto in_grid_desc_n_di_hi_wi_c = const index_t Ni_stride = input_stride[0];
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); const index_t Ci_stride = input_stride[1];
const index_t Di_stride = input_stride[2];
const index_t Hi_stride = input_stride[3];
const index_t Wi_stride = input_stride[4];
const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_di_hi_wi_c, in_grid_desc_n_di_hi_wi_c,
...@@ -139,8 +146,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -139,8 +146,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM] // B[ReduceM]
const auto out_grid_desc_reducemraw = const index_t No_stride = output_stride[0];
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C)); const index_t Co_stride = output_stride[1];
const index_t Do_stride = output_stride[2];
const index_t Ho_stride = output_stride[3];
const index_t Wo_stride = output_stride[4];
const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
out_grid_desc_n_do_ho_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto out_grid_desc_reducem = const auto out_grid_desc_reducem =
transform_tensor_descriptor(out_grid_desc_reducemraw, transform_tensor_descriptor(out_grid_desc_reducemraw,
...@@ -151,7 +171,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -151,7 +171,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
...@@ -160,11 +181,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -160,11 +181,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
Argument(const InDataType* p_in_dev, Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev, OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev, IndexDataType* p_out_indices_dev,
ck::index_t N, std::vector<ck::index_t>& input_lengths,
ck::index_t C, std::vector<ck::index_t>& output_lengths,
std::vector<ck::index_t>& input_spatial_lengths, std::vector<ck::index_t>& input_stride,
std::vector<ck::index_t>& output_stride,
std::vector<ck::index_t>&, // indices_stride
std::vector<ck::index_t>& window_spatial_lengths, std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t>& window_strides, std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& input_left_pads, std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads) std::vector<ck::index_t>& input_right_pads)
...@@ -174,11 +196,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -174,11 +196,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
a_grid_desc_m_k_{}, a_grid_desc_m_k_{},
b_grid_desc_m_{} b_grid_desc_m_{}
{ {
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_lengths,
C, output_lengths,
input_spatial_lengths, input_stride,
output_stride,
window_spatial_lengths, window_spatial_lengths,
output_spatial_lengths,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
...@@ -186,7 +208,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -186,7 +208,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
a_grid_desc_m_k_ = descs[I0]; a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1]; b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C; // C
invariant_lowest_length_ = input_lengths[1];
int32_t reduceLength = int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2]; window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
...@@ -211,6 +234,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -211,6 +234,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
// reduce [D, H, W]
static constexpr index_t InSrcOutDstVectorDim = IsFastestDimReduced ? 1 : 0;
using gridwise_reduce = using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType, GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
...@@ -291,9 +319,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -291,9 +319,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> input_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> output_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
...@@ -307,26 +335,18 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -307,26 +335,18 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4}) if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
index_t N = input_lengths[0]; if(output_stride != indices_stride)
index_t C = input_lengths[1]; throw std::runtime_error("output_stride need to be equal to indices_stride for now");
index_t Di = input_lengths[2];
index_t Hi = input_lengths[3];
index_t Wi = input_lengths[4];
index_t Do = output_lengths[2];
index_t Ho = output_lengths[3];
index_t Wo = output_lengths[4];
std::vector<ck::index_t> input_spatial_lengths = {Di, Hi, Wi};
std::vector<ck::index_t> output_spatial_lengths = {Do, Ho, Wo};
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev), static_cast<IndexDataType*>(p_out_indices_dev),
N, input_lengths,
C, output_lengths,
input_spatial_lengths, input_stride,
output_stride,
indices_stride,
window_lengths, window_lengths,
output_spatial_lengths,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
...@@ -342,7 +362,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -342,7 +362,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ","; str << "DevicePool3dFwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" #include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" #include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_impl.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...@@ -31,8 +31,8 @@ using device_pool2d_fwd_nhwc_instances = ...@@ -31,8 +31,8 @@ using device_pool2d_fwd_nhwc_instances =
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4> DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
...@@ -43,11 +43,11 @@ template <typename InDataType, ...@@ -43,11 +43,11 @@ template <typename InDataType,
using device_pool3d_fwd_ndhwc_instances = using device_pool3d_fwd_ndhwc_instances =
// clang-format off // clang-format off
std::tuple < std::tuple <
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>, DevicePool3dFwdImpl<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>, DevicePool3dFwdImpl<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4> DevicePool3dFwdImpl<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment