"example/vscode:/vscode.git/clone" did not exist on "7a3b49e580671eedfe3aef48e4e673e393bc6999"
Commit b75abe7f authored by rocking
Browse files

Refactor the base class so that generic pooling can be implemented in the future.

parent 646f689c
...@@ -219,14 +219,13 @@ bool pool_test(bool do_verification, ...@@ -219,14 +219,13 @@ bool pool_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
N, {N, C, Hi, Wi},
C, {Y, X},
std::array<ck::index_t, 2>{{Hi, Wi}}, {N, C, Ho, Wo},
std::array<ck::index_t, 2>{{Y, X}},
std::array<ck::index_t, 2>{{Ho, Wo}},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
{2, 3});
if(!pool.IsSupportedArgument(argument_ptr.get())) if(!pool.IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -32,7 +32,8 @@ static void pool3d_host_verify(const Tensor<InDataType>& in, ...@@ -32,7 +32,8 @@ static void pool3d_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 3>& in_left_pads, const std::array<ck::index_t, 3>& in_left_pads,
const std::array<ck::index_t, 3>& /*in_right_pads*/) const std::array<ck::index_t, 3>& /*in_right_pads*/)
{ {
const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
...@@ -241,14 +242,13 @@ bool pool3d_test(bool do_verification, ...@@ -241,14 +242,13 @@ bool pool3d_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
N, {N, C, Di, Hi, Wi},
C, {Z, Y, X},
std::array<ck::index_t, 3>{{Di, Hi, Wi}}, {N, C, Do, Ho, Wo},
std::array<ck::index_t, 3>{{Z, Y, X}},
std::array<ck::index_t, 3>{{Do, Ho, Wo}},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
{2, 3, 4});
if(!pool.IsSupportedArgument(argument_ptr.get())) if(!pool.IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -13,21 +13,20 @@ namespace ck { ...@@ -13,21 +13,20 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t NDimSpatial, ReduceTensorOp ReduceOpId> template <index_t InOutRank, index_t WindowRank, ReduceTensorOp ReduceOpId>
struct DevicePoolFwd : public BaseOperator struct DevicePoolFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* in_dev, MakeArgumentPointer(const void* p_in_dev,
void* out_dev, void* p_out_dev,
void* out_indices_dev, void* p_out_indices_dev,
ck::index_t N, std::array<ck::index_t, InOutRank> input_lengths,
ck::index_t C, std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, std::array<ck::index_t, InOutRank> output_lengths,
std::array<ck::index_t, NDimSpatial> window_spatial_lengths, std::array<ck::index_t, WindowRank> window_strides,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, std::array<ck::index_t, WindowRank> input_left_pads,
std::array<ck::index_t, NDimSpatial> window_strides, std::array<ck::index_t, WindowRank> input_right_pads,
std::array<ck::index_t, NDimSpatial> input_left_pads, std::array<ck::index_t, WindowRank> pooling_dims) = 0;
std::array<ck::index_t, NDimSpatial> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -29,7 +29,7 @@ template <typename InDataType, ...@@ -29,7 +29,7 @@ template <typename InDataType,
ck::index_t ReduceMThreadSliceSize, ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize, ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2, ReduceOpId> struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<5, 2, ReduceOpId>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -38,6 +38,9 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2 ...@@ -38,6 +38,9 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr index_t InOutRank = 5;
static constexpr index_t WindowRank = 2;
using IndexDataType = int32_t; using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
...@@ -57,14 +60,15 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2 ...@@ -57,14 +60,15 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2
static constexpr ck::index_t ReduceK_BlockTileSize = static constexpr ck::index_t ReduceK_BlockTileSize =
ReduceKThreadClusterSize * ReduceKThreadSliceSize; ReduceKThreadClusterSize * ReduceKThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, static auto
MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths, std::array<ck::index_t, WindowRank> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths, std::array<ck::index_t, WindowRank> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths, std::array<ck::index_t, WindowRank> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides, std::array<ck::index_t, WindowRank> window_strides,
std::array<ck::index_t, 2> input_left_pads, std::array<ck::index_t, WindowRank> input_left_pads,
std::array<ck::index_t, 2> input_right_pads) std::array<ck::index_t, WindowRank> input_right_pads)
{ {
const index_t Hi = input_spatial_lengths[0]; const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[1];
...@@ -155,12 +159,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2 ...@@ -155,12 +159,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2
int* p_out_indices_dev, int* p_out_indices_dev,
ck::index_t N, ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, 2>& input_spatial_lengths, std::array<ck::index_t, WindowRank>& input_spatial_lengths,
std::array<ck::index_t, 2>& window_spatial_lengths, std::array<ck::index_t, WindowRank>& window_spatial_lengths,
std::array<ck::index_t, 2>& output_spatial_lengths, std::array<ck::index_t, WindowRank>& output_spatial_lengths,
std::array<ck::index_t, 2>& window_strides, std::array<ck::index_t, WindowRank>& window_strides,
std::array<ck::index_t, 2>& input_left_pads, std::array<ck::index_t, WindowRank>& input_left_pads,
std::array<ck::index_t, 2>& input_right_pads) std::array<ck::index_t, WindowRank>& input_right_pads)
: p_in_dev_{p_in_dev}, : p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev}, p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev}, p_out_indices_dev_{p_out_indices_dev},
...@@ -280,22 +284,31 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2 ...@@ -280,22 +284,31 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<2
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
ck::index_t N, std::array<ck::index_t, InOutRank> input_lengths,
ck::index_t C, std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, 2> input_spatial_lengths, std::array<ck::index_t, InOutRank> output_lengths,
std::array<ck::index_t, 2> window_spatial_lengths, std::array<ck::index_t, WindowRank> window_strides,
std::array<ck::index_t, 2> output_spatial_lengths, std::array<ck::index_t, WindowRank> input_left_pads,
std::array<ck::index_t, 2> window_strides, std::array<ck::index_t, WindowRank> input_right_pads,
std::array<ck::index_t, 2> input_left_pads, std::array<ck::index_t, WindowRank>) override
std::array<ck::index_t, 2> input_right_pads) override
{ {
index_t N = input_lengths[0];
index_t C = input_lengths[1];
index_t Hi = input_lengths[2];
index_t Wi = input_lengths[3];
index_t Ho = output_lengths[2];
index_t Wo = output_lengths[3];
std::array<ck::index_t, WindowRank> input_spatial_lengths = {Hi, Wi};
std::array<ck::index_t, WindowRank> output_spatial_lengths = {Ho, Wo};
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<int*>(p_out_indices_dev), static_cast<int*>(p_out_indices_dev),
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
window_spatial_lengths, window_lengths,
output_spatial_lengths, output_spatial_lengths,
window_strides, window_strides,
input_left_pads, input_left_pads,
......
...@@ -29,7 +29,8 @@ template <typename InDataType, ...@@ -29,7 +29,8 @@ template <typename InDataType,
ck::index_t MThreadSliceSize, ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoolFwd<3, ReduceOpId> struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
: public DevicePoolFwd<5, 3, ReduceOpId>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -38,6 +39,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo ...@@ -38,6 +39,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr index_t InOutRank = 5;
static constexpr index_t WindowRank = 3;
using IndexDataType = int32_t; using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
...@@ -55,14 +59,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo ...@@ -55,14 +59,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, static auto
MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, 3> input_spatial_lengths, std::array<ck::index_t, WindowRank> input_spatial_lengths,
std::array<ck::index_t, 3> window_spatial_lengths, std::array<ck::index_t, WindowRank> window_spatial_lengths,
std::array<ck::index_t, 3> output_spatial_lengths, std::array<ck::index_t, WindowRank> output_spatial_lengths,
std::array<ck::index_t, 3> window_strides, std::array<ck::index_t, WindowRank> window_strides,
std::array<ck::index_t, 3> input_left_pads, std::array<ck::index_t, WindowRank> input_left_pads,
std::array<ck::index_t, 3> input_right_pads) std::array<ck::index_t, WindowRank> input_right_pads)
{ {
const index_t Di = input_spatial_lengths[0]; const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1]; const index_t Hi = input_spatial_lengths[1];
...@@ -161,12 +166,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo ...@@ -161,12 +166,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo
int* p_out_indices_dev, int* p_out_indices_dev,
ck::index_t N, ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, 3>& input_spatial_lengths, std::array<ck::index_t, WindowRank>& input_spatial_lengths,
std::array<ck::index_t, 3>& window_spatial_lengths, std::array<ck::index_t, WindowRank>& window_spatial_lengths,
std::array<ck::index_t, 3>& output_spatial_lengths, std::array<ck::index_t, WindowRank>& output_spatial_lengths,
std::array<ck::index_t, 3>& window_strides, std::array<ck::index_t, WindowRank>& window_strides,
std::array<ck::index_t, 3>& input_left_pads, std::array<ck::index_t, WindowRank>& input_left_pads,
std::array<ck::index_t, 3>& input_right_pads) std::array<ck::index_t, WindowRank>& input_right_pads)
: p_in_dev_{p_in_dev}, : p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev}, p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev}, p_out_indices_dev_{p_out_indices_dev},
...@@ -285,22 +290,33 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo ...@@ -285,22 +290,33 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C : public DevicePoo
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
ck::index_t N, std::array<ck::index_t, InOutRank> input_lengths,
ck::index_t C, std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, 3> input_spatial_lengths, std::array<ck::index_t, InOutRank> output_lengths,
std::array<ck::index_t, 3> window_spatial_lengths, std::array<ck::index_t, WindowRank> window_strides,
std::array<ck::index_t, 3> output_spatial_lengths, std::array<ck::index_t, WindowRank> input_left_pads,
std::array<ck::index_t, 3> window_strides, std::array<ck::index_t, WindowRank> input_right_pads,
std::array<ck::index_t, 3> input_left_pads, std::array<ck::index_t, WindowRank>) override
std::array<ck::index_t, 3> input_right_pads) override
{ {
index_t N = input_lengths[0];
index_t C = input_lengths[1];
index_t Di = input_lengths[2];
index_t Hi = input_lengths[3];
index_t Wi = input_lengths[4];
index_t Do = output_lengths[2];
index_t Ho = output_lengths[3];
index_t Wo = output_lengths[4];
std::array<ck::index_t, WindowRank> input_spatial_lengths = {Di, Hi, Wi};
std::array<ck::index_t, WindowRank> output_spatial_lengths = {Do, Ho, Wo};
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<int*>(p_out_indices_dev), static_cast<int*>(p_out_indices_dev),
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
window_spatial_lengths, window_lengths,
output_spatial_lengths, output_spatial_lengths,
window_strides, window_strides,
input_left_pads, input_left_pads,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment