Commit 158ec146 authored by rocking's avatar rocking
Browse files

Add type to base class

parent 2280e5bb
...@@ -13,7 +13,13 @@ namespace ck { ...@@ -13,7 +13,13 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t InOutRank, index_t WindowRank, ReduceTensorOp ReduceOpId, bool OuputIndex> template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename IndexDataType,
ReduceTensorOp ReduceOpId,
bool OuputIndex>
struct DevicePoolFwd : public BaseOperator struct DevicePoolFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -31,7 +31,7 @@ template <typename InDataType, ...@@ -31,7 +31,7 @@ template <typename InDataType,
ck::index_t ReduceKThreadSliceSize, ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
: public DevicePoolFwd<4, 2, ReduceOpId, OuputIndex> : public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OuputIndex>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
...@@ -31,7 +31,7 @@ template <typename InDataType, ...@@ -31,7 +31,7 @@ template <typename InDataType,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
: public DevicePoolFwd<5, 3, ReduceOpId, OuputIndex> : public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OuputIndex>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
...@@ -11,7 +11,7 @@ namespace instance { ...@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling2d_fwd_nhwc_f16_instances( void add_device_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
......
...@@ -11,7 +11,7 @@ namespace instance { ...@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling2d_fwd_nhwc_f32_instances( void add_device_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
......
...@@ -11,7 +11,7 @@ namespace instance { ...@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling3d_fwd_ndhwc_f16_instances( void add_device_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
......
...@@ -11,7 +11,7 @@ namespace instance { ...@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling3d_fwd_ndhwc_f32_instances( void add_device_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
......
...@@ -11,14 +11,14 @@ namespace instance { ...@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling2d_fwd_nhwc_f16_instances( void add_device_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
} }
void add_device_pooling2d_fwd_nhwc_index_f16_instances( void add_device_pooling2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
......
...@@ -11,14 +11,14 @@ namespace instance { ...@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling2d_fwd_nhwc_f32_instances( void add_device_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
void add_device_pooling2d_fwd_nhwc_index_f32_instances( void add_device_pooling2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
......
...@@ -11,14 +11,14 @@ namespace instance { ...@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling3d_fwd_ndhwc_f16_instances( void add_device_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
} }
void add_device_pooling3d_fwd_ndhwc_index_f16_instances( void add_device_pooling3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
......
...@@ -11,14 +11,14 @@ namespace instance { ...@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling3d_fwd_ndhwc_f32_instances( void add_device_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
void add_device_pooling3d_fwd_ndhwc_index_f32_instances( void add_device_pooling3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment