"examples/vscode:/vscode.git/clone" did not exist on "e62dd5cfa8cf2d6c538366b0b9b85ce3fe613500"
Commit 8332414c authored by rocking's avatar rocking
Browse files

Expand the base class of pool2d, prepare to share base class with pool3d

parent 54c90aae
...@@ -13,8 +13,8 @@ namespace ck { ...@@ -13,8 +13,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <ck::ReduceTensorOp ReduceOpId> template <index_t NDimSpatial, ReduceTensorOp ReduceOpId>
struct DevicePool2dFwd : public BaseOperator struct DevicePoolFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* in_dev, MakeArgumentPointer(const void* in_dev,
...@@ -22,19 +22,16 @@ struct DevicePool2dFwd : public BaseOperator ...@@ -22,19 +22,16 @@ struct DevicePool2dFwd : public BaseOperator
void* out_indices_dev, void* out_indices_dev,
ck::index_t N, ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths, std::array<ck::index_t, NDimSpatial> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides, std::array<ck::index_t, NDimSpatial> window_strides,
std::array<ck::index_t, 2> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, 2> input_right_pads) = 0; std::array<ck::index_t, NDimSpatial> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <ck::ReduceTensorOp ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -29,7 +29,7 @@ template <typename InDataType, ...@@ -29,7 +29,7 @@ template <typename InDataType,
ck::index_t ReduceMThreadSliceSize, ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize, ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd<ReduceOpId> struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2, ReduceOpId>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -141,8 +141,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -141,8 +141,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype( using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment