Commit 8332414c authored by rocking's avatar rocking
Browse files

Expand the base class of pool2d, prepare to share base class with pool3d

parent 54c90aae
......@@ -13,8 +13,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <ck::ReduceTensorOp ReduceOpId>
struct DevicePool2dFwd : public BaseOperator
template <index_t NDimSpatial, ReduceTensorOp ReduceOpId>
struct DevicePoolFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* in_dev,
......@@ -22,19 +22,16 @@ struct DevicePool2dFwd : public BaseOperator
void* out_indices_dev,
ck::index_t N,
ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides,
std::array<ck::index_t, 2> input_left_pads,
std::array<ck::index_t, 2> input_right_pads) = 0;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> window_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> window_strides,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::ReduceTensorOp ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -29,7 +29,7 @@ template <typename InDataType,
ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd<ReduceOpId>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2, ReduceOpId>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -141,8 +141,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(
1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment